Remove useless & Rename

pull/1178/head
wangxinyu 2022-07-14 11:23:40 +00:00 committed by gaotongxiao
parent 567aec5390
commit d8c3aeff3a
13 changed files with 2 additions and 843 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .kie_data_sample import KIEDataSample
from .textdet_data_sample import TextDetDataSample
from .textrecog_data_element import TextRecogDataSample
from .textrecog_data_sample import TextRecogDataSample
__all__ = ['TextDetDataSample', 'TextRecogDataSample', 'KIEDataSample']

View File

@ -1,12 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # NOQA
from .convertors import * # NOQA
from .data_preprocessors import * # NOQA
from .decoders import * # NOQA
from .dictionary import * # NOQA
from .encoders import * # NOQA
from .heads import * # NOQA
from .module_losses import * # NOQA
from .plugins import * # NOQA
from .postprocessors import * # NOQA
from .preprocessors import * # NOQA

View File

@ -1,11 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .abi import ABIConvertor
from .attn import AttnConvertor
from .base import BaseConvertor
from .ctc import CTCConvertor
from .seg import SegConvertor
__all__ = [
'BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor',
'ABIConvertor'
]

View File

@ -1,68 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import mmocr.utils as utils
from mmocr.registry import MODELS
from .attn import AttnConvertor
@MODELS.register_module()
class ABIConvertor(AttnConvertor):
    """Label convertor for ABINet-style encoder-decoder recognizers.

    Identical to :class:`AttnConvertor` except for :meth:`str2tensor`, which
    caps every target at ``max_seq_len`` entries (EOS included) to stay closer
    to ABINet's original implementation.

    Args:
        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
        dict_file (None|str): Character dict file path. If not none,
            higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, higher
            priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        max_seq_len (int): Maximum sequence length of label.
        lower (bool): If True, convert original string to lower case.
        start_end_same (bool): Whether use the same index for
            start and end token or not. Default: True.
    """

    def str2tensor(self, strings):
        """Encode a batch of strings into loss targets.

        Args:
            strings (list[str]): For instance, ['hello', 'world'].

        Returns:
            dict: A dict with two entries.

            - targets (list[Tensor]): One LongTensor per string holding the
              character indexes, cut to at most ``max_seq_len - 1`` entries
              and followed by the EOS index.
            - padded_targets (Tensor): LongTensor of shape
              (len(strings), max_seq_len). Each row is BOS followed by the
              matching target, truncated to ``max_seq_len`` or right-padded
              with the padding index.
        """
        assert utils.is_type_list(strings, str)

        limit = self.max_seq_len
        targets, padded_rows = [], []
        for char_ids in self.str2idx(strings):
            # Leave room for the trailing EOS inside the length budget.
            target = torch.LongTensor(char_ids[:limit - 1] + [self.end_idx])
            targets.append(target)

            # Decoder-side sequence: BOS followed by the target itself.
            with_bos = torch.cat(
                [torch.LongTensor([self.start_idx]), target])
            length = with_bos.size(0)
            if length > limit:
                row = with_bos[:limit]
            else:
                row = torch.full((limit, ), self.padding_idx,
                                 dtype=torch.long)
                row[:length] = with_bos
            padded_rows.append(row)

        return {
            'targets': targets,
            'padded_targets': torch.stack(padded_rows, 0).long()
        }

View File

@ -1,142 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import mmocr.utils as utils
from mmocr.registry import MODELS
from .base import BaseConvertor
@MODELS.register_module()
class AttnConvertor(BaseConvertor):
    """Label convertor for attention (encoder-decoder) based recognizers.

    On top of the base character dictionary it appends, in this order, an
    optional unknown token, the start/end token(s) and a padding token.

    Args:
        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
        dict_file (None|str): Character dict file path. If not none,
            higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, higher
            priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        max_seq_len (int): Maximum sequence length of label.
        lower (bool): If True, convert original string to lower case.
        start_end_same (bool): Whether use the same index for
            start and end token or not. Default: True.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 max_seq_len=40,
                 lower=False,
                 start_end_same=True,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(max_seq_len, int)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.max_seq_len = max_seq_len
        self.lower = lower
        self.start_end_same = start_end_same
        self.update_dict()

    def update_dict(self):
        """Append the special tokens and rebuild ``char2idx``.

        Sets ``unknown_idx`` (``None`` when ``with_unknown`` is False),
        ``start_idx``, ``end_idx`` and ``padding_idx``.
        """
        vocab = self.idx2char

        # Optional unknown token.
        self.unknown_idx = None
        if self.with_unknown:
            self.unknown_idx = len(vocab)
            vocab.append('<UKN>')

        # BOS/EOS: a second entry is added only when start and end differ.
        self.start_idx = len(vocab)
        vocab.append('<BOS/EOS>')
        if not self.start_end_same:
            vocab.append('<BOS/EOS>')
        self.end_idx = len(vocab) - 1

        # Padding always comes last.
        self.padding_idx = len(vocab)
        vocab.append('<PAD>')

        self.char2idx = {char: idx for idx, char in enumerate(vocab)}

    def str2tensor(self, strings):
        """Encode a batch of strings into loss targets.

        Args:
            strings (list[str]): ['hello', 'world']

        Returns:
            dict: A dict with two entries.

            - targets (list[Tensor]): One LongTensor of character indexes
              per string, without BOS/EOS.
            - padded_targets (Tensor): LongTensor of shape
              (len(strings), max_seq_len). Each row is BOS + indexes + EOS,
              truncated to ``max_seq_len`` or right-padded with the padding
              index.
        """
        assert utils.is_type_list(strings, str)

        limit = self.max_seq_len
        bos = torch.LongTensor([self.start_idx])
        eos = torch.LongTensor([self.end_idx])

        targets, padded_rows = [], []
        for char_ids in self.str2idx(strings):
            target = torch.LongTensor(char_ids)
            targets.append(target)

            # Sequence used by the loss: BOS, characters, EOS.
            framed = torch.cat([bos, target, eos])
            length = framed.size(0)
            if length > limit:
                row = framed[:limit]
            else:
                row = torch.full((limit, ), self.padding_idx,
                                 dtype=torch.long)
                row[:length] = framed
            padded_rows.append(row)

        return {
            'targets': targets,
            'padded_targets': torch.stack(padded_rows, 0).long()
        }

    def tensor2idx(self, outputs, img_metas=None):
        """Greedy-decode model outputs into index and score lists.

        Args:
            outputs (tensor): Model outputs with size: N * T * C.
            img_metas (list[dict]): Each dict contains one image info.
                Not used by this method.

        Returns:
            tuple(list[list[int]], list[list[float]]): For every sample the
            decoded character indexes and their softmax scores. Padding
            predictions are skipped and decoding stops at the first end
            token.
        """
        indexes, scores = [], []
        for sample in outputs:
            probs = sample.softmax(dim=-1)
            best_scores, best_ids = torch.max(probs, -1)
            step_ids = best_ids.cpu().detach().numpy().tolist()
            step_scores = best_scores.cpu().detach().numpy().tolist()

            kept_ids, kept_scores = [], []
            for char_id, char_score in zip(step_ids, step_scores):
                if char_id == self.padding_idx:
                    continue
                if char_id == self.end_idx:
                    break
                kept_ids.append(char_id)
                kept_scores.append(char_score)

            indexes.append(kept_ids)
            scores.append(kept_scores)
        return indexes, scores

View File

@ -1,128 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.registry import MODELS
from mmocr.utils import list_from_file
@MODELS.register_module()
class BaseConvertor:
    """Base class converting between text, indexes and tensors.

    The character dictionary is taken from the first available source in
    this order: ``dict_file``, ``dict_list``, then the built-in dictionary
    named by ``dict_type``.

    Args:
        dict_type (str): Type of dict, options are 'DICT36', 'DICT37', 'DICT90'
            and 'DICT91'.
        dict_file (None|str): Character dict file path. If not none,
            the dict_file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
    """

    # Defaults that subclasses overwrite once they add special tokens.
    start_idx = end_idx = padding_idx = 0
    unknown_idx = None
    lower = False

    # Built-in dictionaries; the 37/91 variants add the space character.
    dicts = dict(
        DICT36=tuple('0123456789abcdefghijklmnopqrstuvwxyz'),
        DICT90=tuple('0123456789abcdefghijklmnopqrstuvwxyz'
                     'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
                     '*+,-./:;<=>?@[\\]_`~'),
        # With space character
        DICT37=tuple('0123456789abcdefghijklmnopqrstuvwxyz '),
        DICT91=tuple('0123456789abcdefghijklmnopqrstuvwxyz'
                     'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
                     '*+,-./:;<=>?@[\\]_`~ '))

    def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
        assert dict_file is None or isinstance(dict_file, str)
        assert dict_list is None or isinstance(dict_list, list)

        self.idx2char = []
        if dict_file is not None:
            # One character per line; empty lines are ignored.
            lines = list_from_file(dict_file)
            for line_no, raw_line in enumerate(lines, start=1):
                entry = raw_line.strip('\r\n')
                if len(entry) > 1:
                    raise ValueError('Expect each line has 0 or 1 character, '
                                     f'got {len(entry)} characters '
                                     f'at line {line_no}')
                if entry != '':
                    self.idx2char.append(entry)
        elif dict_list is not None:
            self.idx2char = list(dict_list)
        else:
            if dict_type not in self.dicts:
                raise NotImplementedError(f'Dict type {dict_type} is not '
                                          'supported')
            self.idx2char = list(self.dicts[dict_type])

        assert len(set(self.idx2char)) == len(self.idx2char), \
            'Invalid dictionary: Has duplicated characters.'

        self.char2idx = {
            char: position
            for position, char in enumerate(self.idx2char)
        }

    def num_classes(self):
        """Number of output classes."""
        return len(self.idx2char)

    def str2idx(self, strings):
        """Convert strings to indexes.

        Characters missing from the dictionary map to ``unknown_idx``; when
        that is ``None`` an exception is raised instead.

        Args:
            strings (list[str]): ['hello', 'world'].

        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].

        Raises:
            Exception: If a character is not in the dictionary and no
                unknown index is configured.
        """
        assert isinstance(strings, list)

        batch = []
        for text in strings:
            if self.lower:
                text = text.lower()
            encoded = []
            for char in text:
                position = self.char2idx.get(char, self.unknown_idx)
                if position is None:
                    raise Exception(f'Chararcter: {char} not in dict,'
                                    f' please check gt_label and use'
                                    f' custom dict file,'
                                    f' or set "with_unknown=True"')
                encoded.append(position)
            batch.append(encoded)
        return batch

    def str2tensor(self, strings):
        """Convert text-string to input tensor.

        Must be implemented by subclasses.

        Args:
            strings (list[str]): ['hello', 'world'].

        Returns:
            tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]),
                torch.Tensor([5,4,6,3,7])].
        """
        raise NotImplementedError

    def idx2str(self, indexes):
        """Convert indexes to text strings.

        Args:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].

        Returns:
            strings (list[str]): ['hello', 'world'].
        """
        assert isinstance(indexes, list)
        return [
            ''.join(self.idx2char[position] for position in sequence)
            for sequence in indexes
        ]

    def tensor2idx(self, output):
        """Convert model output tensor to character indexes and scores.

        Must be implemented by subclasses.

        Args:
            output (tensor): The model outputs with size: N * T * C

        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]].
        """
        raise NotImplementedError

View File

@ -1,145 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.registry import MODELS
from .base import BaseConvertor
@MODELS.register_module()
class CTCConvertor(BaseConvertor):
    """Label convertor for CTC loss-based recognizers.

    The CTC blank token is inserted at index 0 of the dictionary and an
    optional unknown token is appended at the end.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the file
            is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        """Add the blank / unknown tokens and rebuild ``char2idx``."""
        vocab = self.idx2char

        # CTC blank occupies index 0, shifting every character by one.
        self.blank_idx = 0
        vocab.insert(0, '<BLK>')

        # Optional unknown token at the end.
        self.unknown_idx = None
        if self.with_unknown:
            self.unknown_idx = len(vocab)
            vocab.append('<UKN>')

        self.char2idx = {char: idx for idx, char in enumerate(vocab)}

    def str2tensor(self, strings):
        """Convert text-string to ctc-loss input tensor.

        Args:
            strings (list[str]): ['hello', 'world'].

        Returns:
            dict: A dict with three entries.

            - targets (list[tensor]): [torch.Tensor([1,2,3,3,4]),
              torch.Tensor([5,4,6,3,7])].
            - flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
            - target_lengths (tensor): torch.IntTensor([5,5]).
        """
        assert utils.is_type_list(strings, str)

        targets = [
            torch.IntTensor(char_ids) for char_ids in self.str2idx(strings)
        ]
        lengths = torch.IntTensor([len(target) for target in targets])
        return {
            'targets': targets,
            'flatten_targets': torch.cat(targets),
            'target_lengths': lengths
        }

    def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
        """Convert model output tensor to index-list.

        Performs greedy CTC decoding: a time step is kept when its top-1
        class is neither blank nor equal to the previous step's top-1 class.
        Only the first ``ceil(T * valid_ratio)`` steps of each sample are
        decoded.

        Args:
            output (tensor): The model outputs with size: N * T * C.
            img_metas (list[dict]): Each dict contains one image info. The
                optional key ``valid_ratio`` defaults to 1.0.
            topk (int): The highest k classes to be returned.
            return_topk (bool): Whether to return topk or just top1.

        Returns:
            tuple: ``(indexes, scores)`` with
            indexes (list[list[int]]) and scores (list[list[float]]) when
            ``return_topk`` is False; otherwise
            ``(indexes_topk, scores_topk)`` where every kept time step
            holds a list of length ``topk``.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)
        assert isinstance(topk, int)
        assert topk >= 1

        valid_ratios = [meta.get('valid_ratio', 1.0) for meta in img_metas]

        probs = F.softmax(output, dim=2).cpu().detach()
        topk_scores, topk_ids = probs.topk(topk, dim=2)
        top1_ids = topk_ids[:, :, 0]
        num_steps = probs.size(1)

        indexes_topk, scores_topk = [], []
        indexes, scores = [], []
        for sample_idx, ratio in enumerate(valid_ratios):
            decode_len = min(num_steps, math.ceil(num_steps * ratio))
            step_classes = top1_ids[sample_idx, :].tolist()

            # Collapse repeats and drop blanks.
            kept_steps = []
            previous = self.blank_idx
            for step in range(decode_len):
                current = step_classes[step]
                if current != previous and current != self.blank_idx:
                    kept_steps.append(step)
                previous = current
            kept_steps = torch.LongTensor(kept_steps)

            # Both selections have shape (num_kept_steps, topk).
            sample_scores = torch.index_select(
                topk_scores[sample_idx, :, :], 0, kept_steps).tolist()
            sample_ids = torch.index_select(
                topk_ids[sample_idx, :, :], 0, kept_steps).tolist()

            indexes_topk.append(sample_ids)
            scores_topk.append(sample_scores)
            indexes.append([candidates[0] for candidates in sample_ids])
            scores.append([candidates[0] for candidates in sample_scores])

        if return_topk:
            return indexes_topk, scores_topk
        return indexes, scores

View File

@ -1,127 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch
import mmocr.utils as utils
from mmocr.registry import MODELS
from .base import BaseConvertor
@MODELS.register_module()
class SegConvertor(BaseConvertor):
    """Convert between text, index and tensor for segmentation based pipeline.

    The background class is inserted at index 0 of the dictionary and an
    optional unknown token is appended at the end.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the
            file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 dict_type='DICT36',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        """Add the background / unknown tokens and rebuild ``char2idx``."""
        # background
        self.idx2char.insert(0, '<BG>')

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {
            char: idx
            for idx, char in enumerate(self.idx2char)
        }

    def tensor2str(self, output, img_metas=None):
        """Convert model output tensor to string labels.

        For every sample the per-pixel argmax map is split into connected
        components of non-background pixels. Each component votes for its
        most frequent class, and the surviving components are read left to
        right (by centroid x) to form the text.

        Args:
            output (tensor): Model outputs with size: N * C * H * W
            img_metas (list[dict]): Each dict contains one image info and
                must provide ``valid_ratio``. Although the default is None,
                a list with one dict per sample is required.

        Returns:
            texts (list[str]): Decoded text labels.
            scores (list[list[float]]): Decoded chars scores, i.e. the
                fraction of a component's counted pixels that belong to
                its majority class. Components whose majority class is the
                unknown token contribute an empty character but still a
                score.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)

        num_classes = len(self.idx2char)
        texts, scores = [], []
        for b in range(output.size(0)):
            seg_pred = output[b].detach()
            valid_width = int(
                output.size(-1) * img_metas[b]['valid_ratio'] + 1)
            seg_res = torch.argmax(
                seg_pred[:, :, :valid_width],
                dim=0).cpu().numpy().astype(np.int32)

            seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
            _, labels, stats, centroids = cv2.connectedComponentsWithStats(
                seg_thr)

            all_res = []
            for i in range(stats.shape[0]):
                # Per-class pixel counts inside this component; classes
                # outside the dictionary are not counted.
                votes = np.bincount(
                    seg_res[labels == i],
                    minlength=num_classes)[:num_classes]
                total_num = votes.sum()
                # A component without any counted pixel (e.g. an empty
                # background label) has no majority class; skip it before
                # the division below would produce 0/0.
                if total_num == 0:
                    continue
                max_cls = int(votes.argmax())
                if max_cls == 0:
                    continue
                max_num = votes[max_cls]
                # Components smaller than 20 majority pixels are treated
                # as noise.
                if max_num < 20:
                    continue
                all_res.append(
                    [max_cls, centroids[i], 1.0 * max_num / total_num])

            # Read characters from left to right.
            all_res.sort(key=lambda res: res[1][0])

            chars, char_scores = [], []
            for char_index, _, char_score in all_res:
                if char_index == self.unknown_idx:
                    chars.append('')
                else:
                    chars.append(self.idx2char[char_index])
                char_scores.append(char_score)

            texts.append(''.join(chars))
            scores.append(char_scores)

        return texts, scores

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .seg_head import SegHead
__all__ = ['SegHead']

View File

@ -1,64 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn
from mmocr.registry import MODELS
@MODELS.register_module()
class SegHead(BaseModule):
    """Head for segmentation based text recognition.

    A 3x3 conv block with batch norm followed by a 1x1 prediction conv,
    optionally followed by interpolation.

    Args:
        in_channels (int): Number of input channels :math:`C`.
        num_classes (int): Number of output classes :math:`C_{out}`.
        upsample_param (dict | None): Keyword arguments forwarded to
            ``F.interpolate``, e.g.
            ``dict(scale_factor=1.0, mode='nearest')``. If None (the
            default), the prediction is returned without interpolation.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 in_channels=128,
                 num_classes=37,
                 upsample_param=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(num_classes, int)
        assert num_classes > 0
        assert upsample_param is None or isinstance(upsample_param, dict)

        self.upsample_param = upsample_param

        # Feature refinement keeps the channel count unchanged.
        self.seg_conv = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN'))

        # Per-pixel class prediction.
        self.pred_conv = nn.Conv2d(
            in_channels, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, out_neck):
        """Predict the segmentation map from the last neck output.

        Args:
            out_neck (list[Tensor]): A list of tensor of shape
                :math:`(N, C_i, H_i, W_i)`. The network only uses the last one
                (``out_neck[-1]``).

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, kH, kW)` where
            :math:`k` is determined by ``upsample_param``.
        """
        refined = self.seg_conv(out_neck[-1])
        prediction = self.pred_conv(refined)
        if self.upsample_param is None:
            return prediction
        return F.interpolate(prediction, **self.upsample_param)

View File

@ -8,9 +8,8 @@ from .nrtr import NRTR
from .robust_scanner import RobustScanner
from .sar import SARNet
from .satrn import SATRN
from .seg_recognizer import SegRecognizer
__all__ = [
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
'SegRecognizer', 'RobustScanner', 'SATRN', 'ABINet', 'MASTER'
'RobustScanner', 'SATRN', 'ABINet', 'MASTER'
]

View File

@ -1,148 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmocr.registry import MODELS
from .base import BaseRecognizer
@MODELS.register_module()
class SegRecognizer(BaseRecognizer):
    """Base class for segmentation based recognizer.

    Args:
        preprocessor (dict, optional): Config of the preprocessor module,
            e.g. TPS. Skipped when None.
        backbone (dict): Config of the backbone. Required.
        neck (dict): Config of the neck. Required.
        head (dict): Config of the head. Required. ``num_classes`` is
            filled in from the label convertor.
        loss (dict): Config of the loss. Required.
        label_convertor (dict): Config of the label convertor. Required.
        train_cfg (dict, optional): Stored as ``self.train_cfg``.
        test_cfg (dict, optional): Stored as ``self.test_cfg``.
        pretrained (str, optional): Deprecated; use ``init_cfg`` instead.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 neck=None,
                 head=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Label_convertor
        assert label_convertor is not None
        self.label_convertor = MODELS.build(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = MODELS.build(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = MODELS.build(backbone)

        # Neck
        assert neck is not None
        self.neck = MODELS.build(neck)

        # Head
        assert head is not None
        # Build from a shallow copy so the caller's config dict is not
        # modified as a side effect of constructing the model.
        head = dict(head, num_classes=self.label_convertor.num_classes())
        self.head = MODELS.build(head)

        # Loss
        assert loss is not None
        self.loss = MODELS.build(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if pretrained is not None:
            # Adjacent literals instead of a backslash continuation, so the
            # message does not depend on how the source is indented.
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    def extract_feat(self, img):
        """Directly extract features from the backbone.

        The preprocessor, when configured, is applied to ``img`` first.
        """
        if self.preprocessor is not None:
            img = self.preprocessor(img)
        return self.backbone(img)

    def forward_train(self, img, img_metas, gt_kernels=None):
        """Run backbone, neck and head, then compute the losses.

        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'filename', and may also contain
                'ori_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
                Not used by this method.
            gt_kernels: Ground truth passed through to the loss as its
                third positional argument.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        feats = self.extract_feat(img)
        out_neck = self.neck(feats)
        out_head = self.head(out_neck)
        return self.loss(out_neck, out_head, gt_kernels)

    def simple_test(self, img, img_metas, **kwargs):
        """Test function without test time augmentation.

        Writes ``valid_ratio`` (``resize_shape[1] / img width``) into every
        dict of ``img_metas`` before decoding.

        Args:
            img (torch.Tensor): Image input tensor.
            img_metas (list[dict]): List of image information. Each dict
                must contain ``resize_shape``.

        Returns:
            list[dict]: One ``dict(text=..., score=...)`` per image, where
            ``score`` is the list of per-character scores.
        """
        feat = self.extract_feat(img)
        out_neck = self.neck(feat)
        out_head = self.head(out_neck)

        for img_meta in img_metas:
            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
            img_meta['valid_ratio'] = valid_ratio

        texts, scores = self.label_convertor.tensor2str(out_head, img_metas)

        # flatten batch results
        return [
            dict(text=text, score=score)
            for text, score in zip(texts, scores)
        ]

    def merge_aug_results(self, aug_results):
        """Pick the augmented result with the highest mean character score.

        Only the first entry of every augmented result is considered.

        Args:
            aug_results (list[list[dict]]): Outputs of :meth:`simple_test`,
                one per augmentation.

        Returns:
            list[dict]: A single-element list ``[dict(text=..., score=...)]``
            where ``score`` is the mean score of the chosen text. With no
            results the text is ``''`` and the score ``-1``.
        """
        out_text, out_score = '', -1
        for result in aug_results:
            text = result[0]['text']
            # max(1, ...) guards the division for empty predictions.
            score = sum(result[0]['score']) / max(1, len(text))
            if score > out_score:
                out_text = text
                out_score = score
        return [dict(text=out_text, score=out_score)]

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation.

        Args:
            imgs (list[tensor]): Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): The metadata of images.

        Returns:
            list[dict]: The merged result, see :meth:`merge_aug_results`.
        """
        aug_results = [
            self.simple_test(img, img_meta, **kwargs)
            for img, img_meta in zip(imgs, img_metas)
        ]
        return self.merge_aug_results(aug_results)