mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
parent
a335444a49
commit
f56a30025a
@ -1,4 +1,3 @@
|
||||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.8.3
|
||||
@ -40,6 +39,10 @@ repos:
|
||||
"-t",
|
||||
"allow_different_nesting",
|
||||
]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.1.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
|
@ -250,7 +250,7 @@ void GridSampleKernel::Compute(OrtKernelContext *context) {
|
||||
int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
|
||||
int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
|
||||
|
||||
// assign nearest neighor pixel value to output pixel
|
||||
// assign nearest neighbor pixel value to output pixel
|
||||
float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
|
||||
const float *inp_ptr_NC = inp_ptr_N;
|
||||
for (int64_t c = 0; c < C;
|
||||
@ -285,7 +285,7 @@ void GridSampleKernel::Compute(OrtKernelContext *context) {
|
||||
++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
|
||||
float coefficients[4];
|
||||
|
||||
// Interpolate 4 values in the x directon
|
||||
// Interpolate 4 values in the x direction
|
||||
for (int64_t i = 0; i < 4; ++i) {
|
||||
coefficients[i] = cubic_interp1d<float>(
|
||||
get_value_bounded<float>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i,
|
||||
|
@ -69,7 +69,7 @@ template <typename T_BBOX>
|
||||
__device__ void emptyBboxInfo(BboxInfo<T_BBOX> *bbox_info) {
|
||||
bbox_info->conf_score = T_BBOX(0);
|
||||
bbox_info->label =
|
||||
-2; // -1 is used for all labels when shared_location is ture
|
||||
-2; // -1 is used for all labels when shared_location is true
|
||||
bbox_info->bbox_idx = -1;
|
||||
bbox_info->kept = false;
|
||||
}
|
||||
|
@ -124,14 +124,14 @@ void TRTBatchedNMS::configurePlugin(
|
||||
}
|
||||
|
||||
bool TRTBatchedNMS::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos == 3) {
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT {
|
||||
|
@ -42,7 +42,7 @@ class TRTBatchedNMS : public TRTPluginBase {
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc* inOut,
|
||||
const nvinfer1::PluginTensorDesc* ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
|
@ -51,14 +51,14 @@ nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions(
|
||||
}
|
||||
|
||||
bool TRTGridSampler::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos == 0) {
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
} else {
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
return ioDesc[pos].type == ioDesc[0].type &&
|
||||
ioDesc[pos].format == ioDesc[0].format;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ class TRTGridSampler : public TRTPluginBase {
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
|
@ -144,7 +144,7 @@ __global__ void grid_sampler_2d_kernel(
|
||||
const int n = index / (out_H * out_W);
|
||||
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
// get the corresponding input x, y coordinates from grid
|
||||
scalar_t ix = grid[grid_offset];
|
||||
scalar_t iy = grid[grid_offset + grid_sCoor];
|
||||
|
||||
@ -193,7 +193,7 @@ __global__ void grid_sampler_2d_kernel(
|
||||
int ix_nearest = static_cast<int>(::round(ix));
|
||||
int iy_nearest = static_cast<int>(::round(iy));
|
||||
|
||||
// assign nearest neighor pixel value to output pixel
|
||||
// assign nearest neighbor pixel value to output pixel
|
||||
auto inp_ptr_NC = input + n * inp_sN;
|
||||
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
|
||||
for (int c = 0; c < C;
|
||||
@ -245,7 +245,7 @@ __global__ void grid_sampler_3d_kernel(
|
||||
const int grid_offset =
|
||||
n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y, z co-ordinates from grid
|
||||
// get the corresponding input x, y, z coordinates from grid
|
||||
scalar_t ix = grid[grid_offset];
|
||||
scalar_t iy = grid[grid_offset + grid_sCoor];
|
||||
scalar_t iz = grid[grid_offset + 2 * grid_sCoor];
|
||||
@ -363,7 +363,7 @@ __global__ void grid_sampler_3d_kernel(
|
||||
int iy_nearest = static_cast<int>(::round(iy));
|
||||
int iz_nearest = static_cast<int>(::round(iz));
|
||||
|
||||
// assign nearest neighor pixel value to output pixel
|
||||
// assign nearest neighbor pixel value to output pixel
|
||||
auto inp_ptr_NC = input + n * inp_sN;
|
||||
auto out_ptr_NCDHW =
|
||||
output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
|
||||
|
@ -107,12 +107,12 @@ void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT {
|
||||
}
|
||||
|
||||
bool TRTInstanceNormalization::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
|
||||
inOut[pos].type == nvinfer1::DataType::kHALF) &&
|
||||
inOut[pos].format == nvinfer1::PluginFormat::kLINEAR &&
|
||||
inOut[pos].type == inOut[0].type);
|
||||
return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT ||
|
||||
ioDesc[pos].type == nvinfer1::DataType::kHALF) &&
|
||||
ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR &&
|
||||
ioDesc[pos].type == ioDesc[0].type);
|
||||
}
|
||||
|
||||
const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT {
|
||||
|
@ -48,7 +48,7 @@ class TRTInstanceNormalization final : public TRTPluginBase {
|
||||
|
||||
// DynamicExt plugin supportsFormat update.
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc* inOut,
|
||||
const nvinfer1::PluginTensorDesc* ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
|
@ -73,15 +73,15 @@ nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions(
|
||||
}
|
||||
|
||||
bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos == 0) {
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
|
||||
} else {
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
return ioDesc[pos].type == ioDesc[0].type &&
|
||||
ioDesc[pos].format == ioDesc[0].format;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -31,7 +31,7 @@ class ModulatedDeformableConvPluginDynamic : public TRTPluginBase {
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
|
@ -65,10 +65,10 @@ nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
|
||||
}
|
||||
|
||||
bool TRTMultiLevelRoiAlign::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
void TRTMultiLevelRoiAlign::configurePlugin(
|
||||
|
@ -29,7 +29,7 @@ class TRTMultiLevelRoiAlign : public TRTPluginBase {
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
|
@ -59,10 +59,10 @@ nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(
|
||||
}
|
||||
|
||||
bool TRTRoIAlign::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
void TRTRoIAlign::configurePlugin(
|
||||
|
@ -23,7 +23,7 @@ class TRTRoIAlign : public TRTPluginBase {
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
|
@ -34,24 +34,24 @@ nvinfer1::DimsExprs TRTScatterND::getOutputDimensions(
|
||||
}
|
||||
|
||||
bool TRTScatterND::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos < nbInputs) {
|
||||
switch (pos) {
|
||||
case 0:
|
||||
// data
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
|
||||
(inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
|
||||
(ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
case 1:
|
||||
// indices
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
case 2:
|
||||
// updates
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
return ioDesc[pos].type == ioDesc[0].type &&
|
||||
ioDesc[pos].format == ioDesc[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
@ -59,8 +59,8 @@ bool TRTScatterND::supportsFormatCombination(
|
||||
switch (pos - nbInputs) {
|
||||
case 0:
|
||||
// output
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
return ioDesc[pos].type == ioDesc[0].type &&
|
||||
ioDesc[pos].format == ioDesc[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ class TRTScatterND : public TRTPluginBase {
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
|
@ -29,7 +29,7 @@ def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
|
||||
backend = get_backend(deploy_cfg).value
|
||||
opset_version = pytorch2onnx_cfg.get('opset_version', 11)
|
||||
|
||||
# load registed symbolic
|
||||
# load registered symbolic
|
||||
register_extra_symbolics(deploy_cfg, backend=backend, opset=opset_version)
|
||||
|
||||
# patch model
|
||||
|
@ -163,7 +163,7 @@ class RewriterContext(object):
|
||||
Examples:
|
||||
>>> from mmdeploy.core import RewriterContext
|
||||
>>> with RewriterContext(cfg, backend='onnxruntime'):
|
||||
>>> # the rewrite has been actived inside the context
|
||||
>>> # the rewrite has been activated inside the context
|
||||
>>> torch.onnx.export(model, inputs, onnx_file)
|
||||
"""
|
||||
|
||||
|
@ -97,7 +97,7 @@ def patch_model(model: nn.Module,
|
||||
backend: str = 'default',
|
||||
recursive: bool = True,
|
||||
**kwargs) -> nn.Module:
|
||||
"""Patch the model, replace the modules that can be rewrited.
|
||||
"""Patch the model, replace the modules that can be rewritten.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to patch.
|
||||
|
@ -60,7 +60,7 @@ SYMBOLIC_REGISTER = Registry('symbolics', build_func=set_symbolic, scope=None)
|
||||
class SymbolicWrapper:
|
||||
"""The wrapper of the symbolic function.
|
||||
|
||||
The wrapper is used to pass context and enviroment to the symbolic.
|
||||
The wrapper is used to pass context and environment to the symbolic.
|
||||
|
||||
Args:
|
||||
cfg (Dict): Config dictionary of deployment.
|
||||
@ -86,7 +86,7 @@ def register_symbolic(func_name: str,
|
||||
"""The decorator of the custom symbolic.
|
||||
|
||||
Args:
|
||||
func_name (str): The function name/path to overide the symbolic.
|
||||
func_name (str): The function name/path to override the symbolic.
|
||||
backend (str): The inference engine name.
|
||||
is_pytorch (bool): Enable this flag if func_name is the name of \
|
||||
a pytorch builtin function.
|
||||
|
@ -3,7 +3,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
|
||||
# Here using mmcv.ops.roi_align.__self__ to find
|
||||
# mmcv.ops.roi_align.RoIAlignFunction, because RoIAlignFunction is not
|
||||
# visiable in mmcv.
|
||||
# visible in mmcv.
|
||||
@SYMBOLIC_REGISTER.register_symbolic(
|
||||
'mmcv.ops.roi_align.__self__', backend='default')
|
||||
def roi_align_default(ctx, g, input, rois, output_size, spatial_scale,
|
||||
|
@ -54,7 +54,7 @@ def create_input(task: Task,
|
||||
|
||||
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
|
||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||
datas = []
|
||||
data_list = []
|
||||
for img in imgs:
|
||||
# prepare data
|
||||
if isinstance(img, np.ndarray):
|
||||
@ -65,9 +65,9 @@ def create_input(task: Task,
|
||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
datas.append(data)
|
||||
data_list.append(data)
|
||||
|
||||
data = collate(datas, samples_per_gpu=len(imgs))
|
||||
data = collate(data_list, samples_per_gpu=len(imgs))
|
||||
|
||||
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
|
||||
data['img'] = [img.data[0] for img in data['img']]
|
||||
|
@ -35,8 +35,8 @@ class DeployBaseRestorer(BaseModel):
|
||||
def forward(self, lq: torch.Tensor, test_mode: bool = False, **kwargs):
|
||||
"""Run test inference for restorer.
|
||||
|
||||
We want forward() to output an image or a evalution result.
|
||||
When test_mode is set, the output is evalution result. Otherwise
|
||||
We want forward() to output an image or a evaluation result.
|
||||
When test_mode is set, the output is evaluation result. Otherwise
|
||||
it is an image.
|
||||
|
||||
Args:
|
||||
@ -104,7 +104,7 @@ class DeployBaseRestorer(BaseModel):
|
||||
outputs: torch.Tensor,
|
||||
lq: torch.Tensor,
|
||||
gt: Optional[torch.Tensor] = None):
|
||||
"""Get evalution results by post-processing model outputs.
|
||||
"""Get evaluation results by post-processing model outputs.
|
||||
|
||||
Args:
|
||||
output (torch.Tensor) : The output high resolution image.
|
||||
|
@ -18,7 +18,7 @@ def _preprocess_cfg(config: Union[str, mmcv.Config]):
|
||||
model_cfg (str | mmcv.Config): The input model config.
|
||||
"""
|
||||
|
||||
# TODO: Differentiate the editting tasks (e.g. restorers and mattors
|
||||
# TODO: Differentiate the editing tasks (e.g. restorers and mattors
|
||||
# preprocess the data in differenet ways)
|
||||
|
||||
keys_to_remove = ['gt', 'gt_path']
|
||||
|
@ -71,7 +71,7 @@ def create_input(task: Task,
|
||||
from mmocr.datasets import build_dataset # noqa: F401
|
||||
test_pipeline = Compose(model_cfg.data.test.pipeline)
|
||||
|
||||
datas = []
|
||||
data_list = []
|
||||
for img in imgs:
|
||||
# prepare data
|
||||
if is_ndarray:
|
||||
@ -84,14 +84,14 @@ def create_input(task: Task,
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
# get tensor from list to stack for batch mode (text detection)
|
||||
datas.append(data)
|
||||
data_list.append(data)
|
||||
|
||||
if isinstance(datas[0]['img'], list) and len(datas) > 1:
|
||||
if isinstance(data_list[0]['img'], list) and len(data_list) > 1:
|
||||
raise Exception('aug test does not support '
|
||||
f'inference with batch size '
|
||||
f'{len(datas)}')
|
||||
f'{len(data_list)}')
|
||||
|
||||
data = collate(datas, samples_per_gpu=len(imgs))
|
||||
data = collate(data_list, samples_per_gpu=len(imgs))
|
||||
|
||||
# process img_metas
|
||||
if isinstance(data['img_metas'], list):
|
||||
|
@ -47,15 +47,15 @@ def create_input(task: Task,
|
||||
cfg.data.test.pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
||||
|
||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||
datas = []
|
||||
data_list = []
|
||||
for img in imgs:
|
||||
# prepare data
|
||||
data = dict(img=img)
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
datas.append(data)
|
||||
data_list.append(data)
|
||||
|
||||
data = collate(datas, samples_per_gpu=len(imgs))
|
||||
data = collate(data_list, samples_per_gpu=len(imgs))
|
||||
|
||||
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
|
||||
data['img'] = [img.data[0][None, :] for img in data['img']]
|
||||
|
@ -34,7 +34,7 @@ def get_task_type(deploy_cfg: Union[str, mmcv.Config], default=None) -> Task:
|
||||
|
||||
Args:
|
||||
deploy_cfg (str | mmcv.Config): The path or content of config.
|
||||
default (str): If the "task" field of config is emtpy, then return
|
||||
default (str): If the "task" field of config is empty, then return
|
||||
default task type.
|
||||
|
||||
Returns:
|
||||
@ -56,7 +56,7 @@ def get_codebase(deploy_cfg: Union[str, mmcv.Config],
|
||||
|
||||
Args:
|
||||
deploy_cfg (str | mmcv.Config): The path or content of config.
|
||||
default (str): If the "codebase" field of config is emtpy, then return
|
||||
default (str): If the "codebase" field of config is empty, then return
|
||||
default codebase type.
|
||||
|
||||
Returns:
|
||||
@ -77,7 +77,7 @@ def get_backend(deploy_cfg: Union[str, mmcv.Config], default=None) -> Backend:
|
||||
|
||||
Args:
|
||||
deploy_cfg (str | mmcv.Config): The path or content of config.
|
||||
default (str): If the "backend" field of config is emtpy, then return
|
||||
default (str): If the "backend" field of config is empty, then return
|
||||
default backend type.
|
||||
|
||||
Returns:
|
||||
|
@ -88,7 +88,7 @@ class TimeCounter:
|
||||
registried function will be activated.
|
||||
warmup (int): the warm up steps, default 1.
|
||||
log_interval (int): interval between each log, default 1.
|
||||
with_sync (bool): wheather use cuda synchronize for time counting,
|
||||
with_sync (bool): whether use cuda synchronize for time counting,
|
||||
default False.
|
||||
"""
|
||||
assert warmup >= 1
|
||||
|
@ -97,7 +97,7 @@ def test_module_rewriter():
|
||||
rewrited_result = rewrited_bottle_nect(x)
|
||||
torch.testing.assert_allclose(rewrited_result, result * 2)
|
||||
|
||||
# wrong backend should not be rewrited
|
||||
# wrong backend should not be rewritten
|
||||
|
||||
rewrited_model = patch_model(model, cfg=cfg)
|
||||
rewrited_bottle_nect = rewrited_model.layer1[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user