SDK ocr 2.0 (#1006)

* add deploy runner

* fix text_det wrapper

* fix recog

* save

* add crnn support

* update with_padding

* add short scale aspect jitter

* update regression test

* torch2ts

* add test data

* resolve comments
AllentDan 2022-09-19 15:08:51 +08:00 committed by GitHub
parent e37bfda86a
commit 97e0d1228f
17 changed files with 489 additions and 101 deletions


@@ -4,5 +4,10 @@ codebase_config = dict(model_type='sdk')
 backend_config = dict(pipeline=[
     dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadOCRAnnotations',
+        with_polygon=True,
+        with_bbox=True,
+        with_label=True),
+    dict(type='PackTextDetInputs', meta_keys=['img_path', 'ori_shape'])
 ])
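
The SDK pipeline now loads ground-truth polygons so that evaluation can run directly against SDK outputs (see DeployTestLoop near the end of this diff). For orientation, a complete SDK deploy config built around this pipeline might look as follows. This is a minimal sketch: the type/task fields of codebase_config are assumed from mmdeploy's usual config layout and are not part of this diff.

# Hedged sketch; only model_type and backend_config.pipeline come from this diff.
codebase_config = dict(type='mmocr', task='TextDetection', model_type='sdk')
backend_config = dict(pipeline=[
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadOCRAnnotations',
        with_polygon=True,
        with_bbox=True,
        with_label=True),
    dict(type='PackTextDetInputs', meta_keys=['img_path', 'ori_shape'])
])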


@@ -4,5 +4,8 @@ codebase_config = dict(model_type='sdk')
 backend_config = dict(pipeline=[
     dict(type='LoadImageFromFile'),
-    dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
+    dict(type='LoadOCRAnnotations', with_text=True),
+    dict(
+        type='PackTextRecogInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
 ])
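
valid_ratio in meta_keys feeds the CTC decoding below: it is the fraction of the padded width that actually contains image content. A small sketch of the convention (the helper name is hypothetical; the formula follows mmocr's usage):

def compute_valid_ratio(resized_w: int, padded_w: int) -> float:
    # fraction of the padded width occupied by the resized image
    return min(1.0, 1.0 * resized_w / padded_w)

# an image resized to 100 px wide, then right-padded to 160 px
assert compute_valid_ratio(100, 160) == 0.625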


@@ -49,11 +49,10 @@ class CTCConvertor : public MMOCR {
       throw_exception(eInvalidArgument);
     }
     // CTCConverter
     idx2char_.insert(begin(idx2char_), "<BLK>");
-    if (_cfg.value("with_unknown", false)) {
-      unknown_idx_ = static_cast<int>(idx2char_.size());
-      idx2char_.emplace_back("<UKN>");
+    if (_cfg.value("with_padding", false)) {
+      padding_idx_ = static_cast<int>(idx2char_.size());
+      idx2char_.emplace_back("<PAD>");
     }
     model_ = model;
@@ -88,20 +87,19 @@ class CTCConvertor : public MMOCR {
     return make_pointer(to_value(output));
   }

-  static std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c,
-                                                           float valid_ratio) {
+  std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c,
+                                                    float valid_ratio) {
     auto decode_len = std::min(w, static_cast<int>(std::ceil(w * valid_ratio)));
     vector<int> indexes;
     indexes.reserve(decode_len);
     vector<float> scores;
     scores.reserve(decode_len);
-    vector<float> prob(c);
-    int prev = blank_idx_;
+    int prev = padding_idx_;
     for (int t = 0; t < decode_len; ++t, data += c) {
-      softmax(data, prob.data(), c);
+      vector<float> prob(data, data + c);
       auto iter = max_element(begin(prob), end(prob));
       auto index = static_cast<int>(iter - begin(prob));
-      if (index != blank_idx_ && index != prev) {
+      if (index != padding_idx_ && index != prev) {
         indexes.push_back(index);
         scores.push_back(*iter);
       }
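
The decoding above is greedy CTC: a per-step argmax, collapsing consecutive repeats and suppressing the non-character index. A NumPy sketch of the same logic; the final `prev = index` update falls outside the hunk shown and is assumed:

import numpy as np

def ctc_greedy_decode(logits, valid_ratio, padding_idx):
    """logits: (time_steps, num_classes) raw scores, as Tensor2Idx receives."""
    w, c = logits.shape
    decode_len = min(w, int(np.ceil(w * valid_ratio)))
    indexes, scores = [], []
    prev = padding_idx
    for t in range(decode_len):
        index = int(logits[t].argmax())
        if index != padding_idx and index != prev:
            indexes.append(index)
            # with softmax gone (next hunk), this is a raw logit maximum
            scores.append(float(logits[t, index]))
        prev = index  # collapse consecutive repeats
    return indexes, scores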
@@ -123,19 +121,6 @@ class CTCConvertor : public MMOCR {
     return text;
   }

-  // TODO: move softmax & top-k into model
-  static void softmax(const float* src, float* dst, int n) {
-    auto max_val = *std::max_element(src, src + n);
-    float sum{};
-    for (int i = 0; i < n; ++i) {
-      dst[i] = std::exp(src[i] - max_val);
-      sum += dst[i];
-    }
-    for (int i = 0; i < n; ++i) {
-      dst[i] /= sum;
-    }
-  }

 protected:
   static vector<string> SplitLines(const string& s) {
     std::istringstream is(s);
@@ -166,7 +151,7 @@ class CTCConvertor : public MMOCR {
   Model model_;
   static constexpr const int blank_idx_{0};
-  int unknown_idx_{-1};
+  int padding_idx_{-1};
   vector<string> idx2char_;
 };
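
Removing the explicit softmax is safe for greedy decoding because softmax is strictly monotonic and never changes the per-step argmax:

import numpy as np

logits = np.array([2.0, -1.0, 0.5])
probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
assert logits.argmax() == probs.argmax()  # softmax preserves the argmax

The trade-off is that the reported per-character scores are now raw logit maxima rather than probabilities.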


@@ -46,17 +46,12 @@ class DBHead : public MMOCR {
   Result<Value> operator()(const Value& _data, const Value& _prob) const {
     auto conf = _prob["output"].get<Tensor>();
-    if (!(conf.shape().size() == 4 && conf.data_type() == DataType::kFLOAT)) {
+    if (!(conf.shape().size() == 3 && conf.data_type() == DataType::kFLOAT)) {
       MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", conf.shape(),
                      (int)conf.data_type());
       return Status(eNotSupported);
     }
-    // drop batch dimension
-    conf.Squeeze(0);
+    conf = conf.Slice(0);

     std::vector<std::vector<cv::Point>> contours;
     std::vector<float> scores;
     OUTCOME_TRY(impl_->Process(conf, mask_thr_, max_candidates_, contours, scores));
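
For context, the post-processing invoked here binarizes the probability map and extracts candidate text contours. A rough Python analogue, illustrative rather than the SDK's actual implementation:

import cv2
import numpy as np

def db_candidates(prob_map: np.ndarray, mask_thr: float, max_candidates: int):
    """prob_map: (H, W) float map from DBNet's head."""
    mask = (prob_map > mask_thr).astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_LIST,
                                   cv2.CHAIN_APPROX_SIMPLE)
    return contours[:max_candidates]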


@@ -0,0 +1,143 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <cassert>
#include <cmath>
#include <set>

#include "mmdeploy/archive/json_archive.h"
#include "mmdeploy/archive/value_archive.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/preprocess/transform/resize.h"
#include "mmdeploy/preprocess/transform/transform.h"
#include "opencv2/imgproc.hpp"
#include "opencv_utils.h"

using namespace std;

namespace mmdeploy {

class RescaleToHeightImpl : public Module {
 public:
  explicit RescaleToHeightImpl(const Value& args) noexcept {
    height_ = args.value("height", height_);
    min_width_ = args.contains("min_width") && args["min_width"].is_number_integer()
                     ? args["min_width"].get<int>()
                     : min_width_;
    max_width_ = args.contains("max_width") && args["max_width"].is_number_integer()
                     ? args["max_width"].get<int>()
                     : max_width_;
    width_divisor_ = args.contains("width_divisor") && args["width_divisor"].is_number_integer()
                         ? args["width_divisor"].get<int>()
                         : width_divisor_;
    resize_type_ = args.contains("resize_type") && args["resize_type"].is_string()
                       ? args["resize_type"].get<string>()
                       : resize_type_;
    stream_ = args["context"]["stream"].get<Stream>();
  }

  ~RescaleToHeightImpl() override = default;

  Result<Value> Process(const Value& input) override {
    MMDEPLOY_DEBUG("input: {}", input);
    auto dst_height = height_;
    auto dst_min_width = min_width_;
    auto dst_max_width = max_width_;
    std::vector<int> img_shape;  // NHWC
    from_value(input["img_shape"], img_shape);
    std::vector<int> ori_shape;  // NHWC
    from_value(input["ori_shape"], ori_shape);

    auto ori_height = ori_shape[1];
    auto ori_width = ori_shape[2];
    auto valid_ratio = 1.f;
    Device host{"cpu"};
    auto _img = input["img"].get<Tensor>();
    OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_));
    stream_.Wait().value();

    // scale the width proportionally to the fixed target height, then clamp
    // it into [min_width, max_width] and round to a multiple of width_divisor
    Tensor img_resize;
    auto new_width = static_cast<int>(std::ceil(1.f * dst_height / ori_height * ori_width));
    auto width_divisor = width_divisor_;
    if (dst_min_width > 0) {
      new_width = std::max(dst_min_width, new_width);
    }
    if (dst_max_width > 0) {
      new_width = std::min(dst_max_width, new_width);
    }
    if (new_width % width_divisor != 0) {
      new_width = static_cast<int>(std::round(1.f * new_width / width_divisor)) * width_divisor;
    }
    img_resize = ResizeImage(img, dst_height, new_width);

    Value output = input;
    output["img"] = img_resize;
    output["resize_shape"] = to_value(img_resize.desc().shape);
    output["pad_shape"] = output["resize_shape"];
    output["ori_shape"] = input["ori_shape"];
    output["scale"] = to_value(std::vector<int>({new_width, dst_height}));
    output["valid_ratio"] = valid_ratio;
    MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
    return output;
  }

  Tensor ResizeImage(const Tensor& img, int dst_h, int dst_w) {
    TensorDesc desc = img.desc();
    assert(desc.shape.size() == 4);
    assert(desc.data_type == DataType::kINT8);
    int h = desc.shape[1];
    int w = desc.shape[2];
    int c = desc.shape[3];
    assert(c == 3 || c == 1);
    cv::Mat src_mat, dst_mat;
    if (3 == c) {  // rgb
      src_mat = cv::Mat(h, w, CV_8UC3, const_cast<uint8_t*>(img.data<uint8_t>()));
    } else {  // gray
      src_mat = cv::Mat(h, w, CV_8UC1, const_cast<uint8_t*>(img.data<uint8_t>()));
    }
    cv::Size size{dst_w, dst_h};
    cv::resize(src_mat, dst_mat, size, 0., 0., cv::INTER_LINEAR);
    // the deleter captures dst_mat to keep its buffer alive for the tensor
    return Tensor({desc.device, desc.data_type, {1, dst_h, dst_w, c}, ""},
                  {dst_mat.data, [mat = dst_mat](void* ptr) {}});
  }

 protected:
  int height_{-1};
  int min_width_{-1};
  int max_width_{-1};
  bool keep_aspect_ratio_{true};
  int width_divisor_{1};
  std::string resize_type_{"Resize"};
  Stream stream_;
};

class RescaleToHeightImplCreator : public Creator<RescaleToHeightImpl> {
 public:
  const char* GetName() const override { return "cpu"; }
  int GetVersion() const override { return 1; }
  ReturnType Create(const Value& args) override {
    return std::make_unique<RescaleToHeightImpl>(args);
  }
};

MMDEPLOY_DEFINE_REGISTRY(RescaleToHeightImpl);
REGISTER_MODULE(RescaleToHeightImpl, RescaleToHeightImplCreator);

class RescaleToHeight : public Transform {
 public:
  explicit RescaleToHeight(const Value& args) : Transform(args) {
    impl_ = Instantiate<RescaleToHeightImpl>("RescaleToHeight", args);
  }
  ~RescaleToHeight() override = default;

  Result<Value> Process(const Value& input) override { return impl_->Process(input); }

 private:
  std::unique_ptr<RescaleToHeightImpl> impl_;
  static const std::string name_;
};

DECLARE_AND_REGISTER_MODULE(Transform, RescaleToHeight, 1);
} // namespace mmdeploy
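
The width arithmetic mirrors mmocr's RescaleToHeight. A Python sketch of the same computation (hypothetical helper name):

import math

def rescale_to_height(ori_h, ori_w, height, min_width=-1, max_width=-1,
                      width_divisor=1):
    new_width = math.ceil(height / ori_h * ori_w)
    if min_width > 0:
        new_width = max(min_width, new_width)
    if max_width > 0:
        new_width = min(max_width, new_width)
    if new_width % width_divisor != 0:
        new_width = round(new_width / width_divisor) * width_divisor
    return height, new_width

# a CRNN-style recognizer input with fixed height 32
assert rescale_to_height(64, 200, height=32) == (32, 100)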


@@ -0,0 +1,151 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <cassert>
#include <cmath>
#include <set>

#include "mmdeploy/archive/json_archive.h"
#include "mmdeploy/archive/value_archive.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/preprocess/transform/resize.h"
#include "mmdeploy/preprocess/transform/transform.h"
#include "opencv2/imgproc.hpp"
#include "opencv_utils.h"

using namespace std;

namespace mmdeploy {

class ShortScaleAspectJitterImpl : public Module {
 public:
  explicit ShortScaleAspectJitterImpl(const Value& args) noexcept {
    short_size_ = args.contains("short_size") && args["short_size"].is_number_integer()
                      ? args["short_size"].get<int>()
                      : short_size_;
    if (args["ratio_range"].is_array() && args["ratio_range"].size() == 2) {
      ratio_range_[0] = args["ratio_range"][0].get<float>();
      ratio_range_[1] = args["ratio_range"][1].get<float>();
    } else {
      throw std::invalid_argument("'ratio_range' should be a float array of size 2");
    }
    if (args["aspect_ratio_range"].is_array() && args["aspect_ratio_range"].size() == 2) {
      aspect_ratio_range_[0] = args["aspect_ratio_range"][0].get<float>();
      aspect_ratio_range_[1] = args["aspect_ratio_range"][1].get<float>();
    } else {
      throw std::invalid_argument("'aspect_ratio_range' should be a float array of size 2");
    }
    scale_divisor_ = args.contains("scale_divisor") && args["scale_divisor"].is_number_integer()
                         ? args["scale_divisor"].get<int>()
                         : scale_divisor_;
    resize_type_ = args.contains("resize_type") && args["resize_type"].is_string()
                       ? args["resize_type"].get<string>()
                       : resize_type_;
    stream_ = args["context"]["stream"].get<Stream>();
  }

  ~ShortScaleAspectJitterImpl() override = default;

  Result<Value> Process(const Value& input) override {
    MMDEPLOY_DEBUG("input: {}", input);
    auto short_size = short_size_;
    auto ratio_range = ratio_range_;
    auto aspect_ratio_range = aspect_ratio_range_;
    auto scale_divisor = scale_divisor_;
    // jitter is a training-time augmentation; the SDK only supports the
    // deterministic case where both ranges are fixed to 1.0
    if (ratio_range[0] != 1.0 || ratio_range[1] != 1.0 || aspect_ratio_range[0] != 1.0 ||
        aspect_ratio_range[1] != 1.0) {
      MMDEPLOY_ERROR("unsupported `ratio_range` and `aspect_ratio_range`");
      return Status(eNotSupported);
    }
    std::vector<int> img_shape;  // NHWC
    from_value(input["img_shape"], img_shape);
    std::vector<int> ori_shape;  // NHWC
    from_value(input["ori_shape"], ori_shape);

    auto ori_height = ori_shape[1];
    auto ori_width = ori_shape[2];
    Device host{"cpu"};
    auto _img = input["img"].get<Tensor>();
    OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_));
    stream_.Wait().value();

    // scale the short side to short_size, then round both sides up to a
    // multiple of scale_divisor
    Tensor img_resize;
    auto scale = static_cast<float>(1.0 * short_size / std::min(img_shape[1], img_shape[2]));
    auto dst_height = static_cast<int>(std::round(scale * img_shape[1]));
    auto dst_width = static_cast<int>(std::round(scale * img_shape[2]));
    dst_height = static_cast<int>(std::ceil(1.0 * dst_height / scale_divisor) * scale_divisor);
    dst_width = static_cast<int>(std::ceil(1.0 * dst_width / scale_divisor) * scale_divisor);
    std::vector<float> scale_factor = {1.f * dst_width / img_shape[2],
                                       1.f * dst_height / img_shape[1]};
    img_resize = ResizeImage(img, dst_height, dst_width);

    Value output = input;
    output["img"] = img_resize;
    output["resize_shape"] = to_value(img_resize.desc().shape);
    output["scale"] = to_value(std::vector<int>({dst_width, dst_height}));
    output["scale_factor"] = to_value(scale_factor);
    MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
    return output;
  }

  Tensor ResizeImage(const Tensor& img, int dst_h, int dst_w) {
    TensorDesc desc = img.desc();
    assert(desc.shape.size() == 4);
    assert(desc.data_type == DataType::kINT8);
    int h = desc.shape[1];
    int w = desc.shape[2];
    int c = desc.shape[3];
    assert(c == 3 || c == 1);
    cv::Mat src_mat, dst_mat;
    if (3 == c) {  // rgb
      src_mat = cv::Mat(h, w, CV_8UC3, const_cast<uint8_t*>(img.data<uint8_t>()));
    } else {  // gray
      src_mat = cv::Mat(h, w, CV_8UC1, const_cast<uint8_t*>(img.data<uint8_t>()));
    }
    cv::Size size{dst_w, dst_h};
    cv::resize(src_mat, dst_mat, size, 0., 0., cv::INTER_LINEAR);
    // the deleter captures dst_mat to keep its buffer alive for the tensor
    return Tensor({desc.device, desc.data_type, {1, dst_h, dst_w, c}, ""},
                  {dst_mat.data, [mat = dst_mat](void* ptr) {}});
  }

 protected:
  int short_size_{736};
  std::vector<float> ratio_range_{0.7f, 1.3f};
  std::vector<float> aspect_ratio_range_{0.9f, 1.1f};
  int scale_divisor_{1};
  std::string resize_type_{"Resize"};
  Stream stream_;
};

class ShortScaleAspectJitterImplCreator : public Creator<ShortScaleAspectJitterImpl> {
 public:
  const char* GetName() const override { return "cpu"; }
  int GetVersion() const override { return 1; }
  ReturnType Create(const Value& args) override {
    return std::make_unique<ShortScaleAspectJitterImpl>(args);
  }
};

MMDEPLOY_DEFINE_REGISTRY(ShortScaleAspectJitterImpl);
REGISTER_MODULE(ShortScaleAspectJitterImpl, ShortScaleAspectJitterImplCreator);

class ShortScaleAspectJitter : public Transform {
 public:
  explicit ShortScaleAspectJitter(const Value& args) : Transform(args) {
    impl_ = Instantiate<ShortScaleAspectJitterImpl>("ShortScaleAspectJitter", args);
  }
  ~ShortScaleAspectJitter() override = default;

  Result<Value> Process(const Value& input) override { return impl_->Process(input); }

 private:
  std::unique_ptr<ShortScaleAspectJitterImpl> impl_;
  static const std::string name_;
};

DECLARE_AND_REGISTER_MODULE(Transform, ShortScaleAspectJitter, 1);
} // namespace mmdeploy
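
At test time both jitter ranges are pinned to 1.0, so the transform reduces to a deterministic short-side resize rounded up to the scale divisor. A Python sketch:

import math

def short_scale_resize(h, w, short_size=736, scale_divisor=1):
    scale = short_size / min(h, w)
    dst_h = math.ceil(round(scale * h) / scale_divisor) * scale_divisor
    dst_w = math.ceil(round(scale * w) / scale_divisor) * scale_divisor
    return dst_h, dst_w

assert short_scale_resize(720, 1280, short_size=736, scale_divisor=32) == (736, 1312)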


@@ -41,12 +41,15 @@ def torch2torchscript(img: Any,
     from mmdeploy.apis import build_task_processor
     task_processor = build_task_processor(model_cfg, deploy_cfg, device)

     torch_model = task_processor.build_pytorch_model(model_checkpoint)
-    _, model_inputs = task_processor.create_input(
+    data, model_inputs = task_processor.create_input(
         img,
         input_shape,
         data_preprocessor=getattr(torch_model, 'data_preprocessor', None))
-    if not isinstance(model_inputs, torch.Tensor):
+    if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
         model_inputs = model_inputs[0]
+    data_samples = data['data_samples']
+    input_metas = {'data_samples': data_samples, 'mode': 'predict'}

     context_info = dict(deploy_cfg=deploy_cfg)
     backend = get_backend(deploy_cfg).value
     output_prefix = osp.join(work_dir, osp.splitext(save_file)[0])
@@ -56,6 +59,7 @@ def torch2torchscript(img: Any,
     trace(
         torch_model,
         model_inputs,
+        input_metas=input_metas,
         output_path_prefix=output_prefix,
         backend=backend,
         context_info=context_info,

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from copy import deepcopy
+from functools import partial
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import torch
@@ -13,6 +14,7 @@ from ..core import PIPELINE_MANAGER
 @PIPELINE_MANAGER.register_pipeline()
 def trace(func: torch.nn.Module,
           inputs: Union[torch.Tensor, Tuple],
+          input_metas: Optional[Dict] = None,
           output_path_prefix: Optional[str] = None,
           backend: Union[Backend, str] = 'default',
           context_info: Dict = dict(),
@@ -89,6 +91,7 @@ def trace(func: torch.nn.Module,
     if isinstance(func, torch.nn.Module):
         ir = IR.get(get_ir_config(deploy_cfg)['type'])
         func = patch_model(func, cfg=deploy_cfg, backend=backend, ir=ir)
+        func.forward = partial(func.forward, **input_metas)

     with RewriterContext(**context_info), torch.no_grad():
         # for exporting models with weight that depends on inputs
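
Why the partial binding: torch.jit.trace only feeds positional tensor inputs to forward, so run-time keyword arguments such as data_samples and mode='predict' must be frozen in beforehand. A self-contained sketch of the trick with a toy module (not mmdeploy code):

from functools import partial

import torch

class Toy(torch.nn.Module):
    def forward(self, x, scale=1.0, mode='tensor'):
        return x * scale if mode == 'predict' else x

toy = Toy()
# freeze the non-tensor kwargs, as trace() now does via input_metas
toy.forward = partial(toy.forward, scale=2.0, mode='predict')
traced = torch.jit.trace(toy, torch.ones(1))
assert torch.equal(traced(torch.ones(1)), torch.full((1,), 2.0))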


@@ -59,17 +59,6 @@ def get_model_name_customs(deploy_cfg: mmengine.Config,
     name = task_processor.get_model_name()
     customs = []
     if task == Task.TEXT_RECOGNITION:
-        from mmocr.models.builder import build_convertor
-        label_convertor = model_cfg.model.label_convertor
-        assert label_convertor is not None, 'model_cfg contains no label '
-        'convertor'
-        max_seq_len = 40  # default value in EncodeDecodeRecognizer of mmocr
-        label_convertor.update(max_seq_len=max_seq_len)
-        label_convertor = build_convertor(label_convertor)
-        fd = open(f'{work_dir}/dict_file.txt', mode='w+')
-        for item in label_convertor.idx2char:
-            fd.write(item + '\n')
-        fd.close()
         customs.append('dict_file.txt')
     return name, customs
@@ -195,41 +184,27 @@ def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config):
         'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
         'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg', 'valid_ratio'
     ]
-    if 'transforms' in pipeline[-1]:
-        transforms = pipeline[-1]['transforms']
-        transforms.insert(0, pipeline[0])
-        for transform in transforms:
-            if transform['type'] == 'Resize':
-                transform['size'] = pipeline[-1].img_scale[::-1]
-                if 'img_scale' in transform:
-                    transform.pop('img_scale')
-    else:
-        pipeline = [
-            item for item in pipeline if item['type'] != 'MultiScaleFilpAug'
-        ]
-        transforms = pipeline
     transforms = [
-        item for item in transforms if 'Random' not in item['type']
-        and 'RescaleToZeroOne' not in item['type']
+        item for item in pipeline
+        if 'Random' not in item['type'] and 'RescaleToZeroOne' not in
+        item['type'] and 'Annotation' not in item['type']
     ]
     for i, transform in enumerate(transforms):
         if 'keys' in transform and transform['keys'] == ['lq']:
             transform['keys'] = ['img']
         if 'key' in transform and transform['key'] == 'lq':
             transform['key'] = 'img'
-        if transform['type'] == 'Resize':
-            transform['size'] = transform['scale']
-            del transform['scale']
         if transform['type'] == 'ResizeEdge':
             transform['type'] = 'Resize'
             transform['keep_ratio'] = True
             # the classification SDK is currently buggy here because
             # ResizeEdge is not implemented in the SDK
             transform['size'] = (transform['scale'], transform['scale'])
-        if transform['type'] == 'PackTextDetInputs':
+        if transform['type'] in ('PackTextDetInputs', 'PackTextRecogInputs'):
             meta_keys += transform[
                 'meta_keys'] if 'meta_keys' in transform else []
             transform['meta_keys'] = list(set(meta_keys))
             transform['keys'] = ['img']
             transforms[i]['type'] = 'Collect'
         if transform['type'] == 'PackDetInputs' or \
                 transform['type'] == 'PackClsInputs':
@@ -237,6 +212,24 @@ def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config):
             transform['type'] = 'Collect'
             if 'keys' not in transform:
                 transform['keys'] = ['img']
+        if transform['type'] == 'Resize':
+            transforms[i]['size'] = transforms[i]['scale']
+
+    data_preprocessor = model_cfg.model.data_preprocessor
+    transforms.insert(-1, dict(type='DefaultFormatBundle'))
+    transforms.insert(
+        -2,
+        dict(
+            type='Pad',
+            size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
+    transforms.insert(
+        -3,
+        dict(
+            type='Normalize',
+            to_rgb=data_preprocessor.get('bgr_to_rgb', False),
+            mean=data_preprocessor.get('mean', [0, 0, 0]),
+            std=data_preprocessor.get('std', [1, 1, 1])))

     assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item type'\
         ' of pipeline should be LoadImageFromFile'
@@ -249,13 +242,14 @@ def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config):
         transforms=transforms)


-def get_postprocess(deploy_cfg: mmengine.Config,
-                    model_cfg: mmengine.Config) -> Dict:
+def get_postprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
+                    work_dir: str) -> Dict:
     """Get the post process information for pipeline.json.

     Args:
         deploy_cfg (mmengine.Config): Deploy config dict.
         model_cfg (mmengine.Config): The model config dict.
+        work_dir (str): Work dir to save json files.

     Return:
         dict: Composed of the model name, type, module, input, params and
@@ -264,20 +258,28 @@ def get_postprocess(deploy_cfg: mmengine.Config,
     module = get_codebase(deploy_cfg).value
     type = 'Task'
     name = 'postprocess'
+    params = dict()
     task = get_task_type(deploy_cfg)
     task_processor = build_task_processor(
         model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
-    params = task_processor.get_postprocess()
+    post_processor = task_processor.get_postprocess()

     # TODO remove after adding instance segmentation to task processor
-    if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in params:
+    if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in post_processor:
         task = Task.INSTANCE_SEGMENTATION

     component = task_map[task]['component']
-    if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
-        if 'type' in params:
-            component = params.pop('type')
+    if task not in (Task.SUPER_RESOLUTION, Task.SEGMENTATION):
+        if 'type' in post_processor:
+            component = post_processor.pop('type')
     output = ['post_output']
+    if task == Task.TEXT_RECOGNITION:
+        import shutil
+        shutil.copy(model_cfg.dictionary.dict_file,
+                    f'{work_dir}/dict_file.txt')
+        with_padding = model_cfg.dictionary.get('with_padding', False)
+        params = dict(dict_file='dict_file.txt', with_padding=with_padding)
     return dict(
         type=type,
         module=module,
@@ -323,7 +325,7 @@ def get_pipeline(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
     """
     preprocess = get_preprocess(deploy_cfg, model_cfg)
     infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir=work_dir)
-    postprocess = get_postprocess(deploy_cfg, model_cfg)
+    postprocess = get_postprocess(deploy_cfg, model_cfg, work_dir)
     task = get_task_type(deploy_cfg)
     input_names = preprocess['input']
output_names = postprocess['output']
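
After these changes the text-recognition entry written to pipeline.json carries the dictionary file and padding flag. Roughly, with field values illustrative:

postprocess = dict(
    type='Task',
    module='mmocr',
    name='postprocess',
    component='CTCConvertor',
    params=dict(dict_file='dict_file.txt', with_padding=False),
    output=['post_output'])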


@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Sequence, Union

 import torch
 from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModel
-from mmengine.runner import Runner
+from mmengine.registry import LOOPS
+from mmengine.runner import Runner, TestLoop, autocast


 class DeployTestRunner(Runner):
@@ -62,3 +64,38 @@ class DeployTestRunner(Runner):
             log_file = self._log_file

         return super().build_logger(log_level, log_file, **kwargs)
+
+
+@LOOPS.register_module()
+class DeployTestLoop(TestLoop):
+    """Test loop that skips the model's data_preprocessor, for the SDK.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 testing. Defaults to False.
+    """
+
+    @torch.no_grad()
+    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
+        """Iterate one mini-batch.
+
+        Args:
+            data_batch (Sequence[dict]): Batch of data from dataloader.
+        """
+        self.runner.call_hook(
+            'before_test_iter', batch_idx=idx, data_batch=data_batch)
+        # predictions should be a sequence of BaseDataElement
+        with autocast(enabled=self.fp16):
+            # skip data_preprocessor to avoid Normalize and Pad for the SDK
+            outputs = self.runner.model._run_forward(
+                data_batch, mode='predict')
+        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
+        self.runner.call_hook(
+            'after_test_iter',
+            batch_idx=idx,
+            data_batch=data_batch,
+            outputs=outputs)
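
With the @LOOPS.register_module() registration, enabling the loop is a one-line config change, which is exactly what tools/test.py does for SDK backends at the end of this diff:

# switch the test loop when evaluating through the SDK
test_cfg = dict(type='DeployTestLoop')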


@@ -6,6 +6,7 @@ import numpy as np
 import torch
 from mmengine import Config
 from mmengine.dataset import pseudo_collate
+from mmengine.dist import cast_data_device
 from mmengine.model import BaseDataPreprocessor
 from torch import nn
@@ -161,6 +162,8 @@ class TextDetection(BaseTask):
             data_ = test_pipeline(data_)
             data.append(data_)

         data = pseudo_collate(data)
+        data['inputs'] = cast_data_device(data['inputs'],
+                                          torch.device(self.device))
         if data_preprocessor is not None:
             data = data_preprocessor(data, False)
             return data, data['inputs']
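
cast_data_device, imported from mmengine.dist above, recursively moves tensors in nested containers onto the target device; it keeps the collated inputs on the same device as the model even when the data_preprocessor step is skipped. A minimal sketch:

import torch
from mmengine.dist import cast_data_device

inputs = [torch.zeros(3, 32, 32), torch.zeros(3, 48, 48)]
inputs = cast_data_device(inputs, torch.device('cpu'))  # list structure kept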


@@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Union
 import mmengine
 import torch
 from mmengine.registry import Registry
-from mmengine.structures import BaseDataElement
+from mmengine.structures import BaseDataElement, InstanceData
 from mmocr.structures import TextDetDataSample

 from mmdeploy.codebase.base import BaseBackendModel
@@ -144,12 +144,14 @@ class SDKEnd2EndModel(End2EndModel):
         """
         boundaries = self.wrapper.invoke(inputs[0].permute(
             [1, 2, 0]).contiguous().detach().cpu().numpy())
-        boundaries = [list(x) for x in boundaries]
-        return [
-            dict(
-                boundary_result=boundaries,
-                filename=data_samples[0]['filename'])
-        ]
+        polygons = [boundary[:-1] for boundary in boundaries]
+        scores = torch.Tensor([boundary[-1] for boundary in boundaries])
+        pred_instances = InstanceData()
+        pred_instances.polygons = polygons
+        pred_instances.scores = scores
+        data_samples[0].pred_instances = pred_instances
+        return data_samples


@@ -184,5 +186,6 @@ def build_text_detection_model(model_files: Sequence[str],
             deploy_cfg=deploy_cfg,
             model_cfg=model_cfg,
             **kwargs))
+    backend_text_detector = backend_text_detector.to(device)

     return backend_text_detector


@@ -6,6 +6,7 @@ import numpy as np
 import torch
 from mmengine import Config
 from mmengine.dataset import pseudo_collate
+from mmengine.dist import cast_data_device
 from mmengine.model import BaseDataPreprocessor
 from torch import nn
@@ -167,6 +168,8 @@ class TextRecognition(BaseTask):
             data.append(data_)

         data = pseudo_collate(data)
+        data['inputs'] = cast_data_device(data['inputs'],
+                                          torch.device(self.device))
         if data_preprocessor is not None:
             data = data_preprocessor(data, False)
         return data, data['inputs']
@@ -223,7 +226,7 @@ class TextRecognition(BaseTask):
         """
         input_shape = get_input_shape(self.deploy_cfg)
         model_cfg = process_model_config(self.model_cfg, [''], input_shape)
-        preprocess = model_cfg.data.test.pipeline
+        preprocess = model_cfg.test_dataloader.dataset.pipeline
         return preprocess

     def get_postprocess(self) -> Dict:
@@ -232,7 +235,9 @@ class TextRecognition(BaseTask):
         Return:
             dict: Composed of the postprocess information.
         """
-        postprocess = self.model_cfg.label_convertor
+        postprocess = self.model_cfg.model.decoder.postprocessor
+        if postprocess.type == 'CTCPostProcessor':
+            postprocess.type = 'CTCConvertor'
         return postprocess

     def get_model_name(self) -> str:


@@ -4,6 +4,7 @@ from typing import Sequence, Union
 import mmengine
 import torch
 from mmengine.registry import Registry
+from mmengine.structures import LabelData
+from mmocr.utils.typing import RecSampleList

 from mmdeploy.codebase.base import BaseBackendModel
@@ -112,21 +113,26 @@ class End2EndModel(BaseBackendModel):
 class SDKEnd2EndModel(End2EndModel):
     """SDK inference class, converts SDK output to mmocr format."""

-    def forward(self, img: Sequence[torch.Tensor],
-                img_metas: Sequence[Sequence[dict]], *args, **kwargs):
+    def forward(self, inputs: Sequence[torch.Tensor],
+                data_samples: RecSampleList, *args, **kwargs):
         """Run forward inference.

         Args:
-            imgs (torch.Tensor | Sequence[torch.Tensor]): Image input tensor.
-            img_metas (Sequence[dict]): List of image information.
+            inputs (torch.Tensor): Image input tensor.
+            data_samples (list[TextRecogDataSample]): A list of N datasamples,
+                containing meta information and gold annotations for each of
+                the images.

         Returns:
             list[str]: Text label result of each image.
         """
-        text, score = self.wrapper.invoke(
-            img[0].contiguous().detach().cpu().numpy())
-        results = [dict(text=text, score=score)]
-        return results
+        text, score = self.wrapper.invoke(inputs[0].permute(
+            [1, 2, 0]).contiguous().detach().cpu().numpy())
+        pred_text = LabelData()
+        pred_text.score = score
+        pred_text.item = text
+        data_samples[0].pred_text = pred_text
+        return data_samples


@@ -161,5 +167,6 @@ def build_text_recognition_model(model_files: Sequence[str],
             deploy_cfg=deploy_cfg,
             model_cfg=model_cfg,
             **kwargs))
+    backend_text_recognizer = backend_text_recognizer.to(device)

     return backend_text_recognizer


@@ -175,13 +175,23 @@ models:
   - name: DBNet
     metafile: configs/textdet/dbnet/metafile.yml
     model_configs:
-      - configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
+      - configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py
     pipelines:
       - *pipeline_ts_detection_fp32
       - *pipeline_ort_detection_dynamic_fp32
       - *pipeline_trt_detection_dynamic_fp16
       - *pipeline_ncnn_detection_static_fp32
       - *pipeline_pplnn_detection_dynamic_fp32
       - *pipeline_openvino_detection_dynamic_fp32
+
+  - name: PANet
+    metafile: configs/textdet/panet/metafile.yml
+    model_configs:
+      - configs/textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py
+    pipelines:
+      - *pipeline_ts_detection_fp32
+      - *pipeline_ort_detection_dynamic_fp32
+      # - *pipeline_trt_detection_dynamic_fp32
+      - *pipeline_trt_detection_dynamic_fp16
+      # - *pipeline_trt_detection_dynamic_int8
+      - *pipeline_ncnn_detection_static_fp32
+      - *pipeline_pplnn_detection_dynamic_fp32
+      - *pipeline_openvino_detection_dynamic_fp32
@@ -189,17 +199,10 @@ models:
   - name: CRNN
     metafile: configs/textrecog/crnn/metafile.yml
     model_configs:
-      - configs/textrecog/crnn/crnn_academic_dataset.py
+      - configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py
     pipelines:
       - *pipeline_ts_recognition_fp32
       - *pipeline_ort_recognition_dynamic_fp32
       - *pipeline_trt_recognition_dynamic_fp16
       - *pipeline_ncnn_recognition_static_fp32
       - *pipeline_pplnn_recognition_dynamic_fp32
-
-  - name: SAR
-    metafile: configs/textrecog/sar/metafile.yml
-    model_configs:
-      - configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
-    pipelines:
-      - *pipeline_ort_recognition_dynamic_fp32


@@ -0,0 +1,36 @@
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
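
This 36-entry dictionary (digits plus lowercase letters) is what get_postprocess now copies next to pipeline.json. The SDK's CTCConvertor indexes it as sketched below (cf. the crnn.cpp hunk above):

chars = [line.strip() for line in open('dict_file.txt') if line.strip()]
idx2char = ['<BLK>'] + chars      # blank_idx_ == 0 in CTCConvertor
with_padding = True               # from the postprocess params
if with_padding:
    padding_idx = len(idx2char)   # padding_idx_ points past the last char
    idx2char.append('<PAD>')
assert len(idx2char) == 38        # 36 chars + <BLK> + <PAD>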


@@ -6,7 +6,8 @@ from copy import deepcopy
 from mmengine import DictAction

 from mmdeploy.apis import build_task_processor
-from mmdeploy.utils.config_utils import load_config
+from mmdeploy.utils.config_utils import get_backend, load_config
+from mmdeploy.utils.constants import Backend
 from mmdeploy.utils.timer import TimeCounter
@@ -88,6 +89,8 @@ def main():
     # load deploy_cfg
     deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
+    if get_backend(deploy_cfg) == Backend.SDK:
+        model_cfg.test_cfg.type = 'DeployTestLoop'

     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None: