mirror of https://github.com/open-mmlab/mmocr.git
Remove useless & Rename
parent 567aec5390
commit d8c3aeff3a
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .kie_data_sample import KIEDataSample
 from .textdet_data_sample import TextDetDataSample
-from .textrecog_data_element import TextRecogDataSample
+from .textrecog_data_sample import TextRecogDataSample
 
 __all__ = ['TextDetDataSample', 'TextRecogDataSample', 'KIEDataSample']
@@ -1,12 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import *  # NOQA
from .convertors import *  # NOQA
from .data_preprocessors import *  # NOQA
from .decoders import *  # NOQA
from .dictionary import *  # NOQA
from .encoders import *  # NOQA
from .heads import *  # NOQA
from .module_losses import *  # NOQA
from .plugins import *  # NOQA
from .postprocessors import *  # NOQA
from .preprocessors import *  # NOQA
@@ -1,11 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .abi import ABIConvertor
from .attn import AttnConvertor
from .base import BaseConvertor
from .ctc import CTCConvertor
from .seg import SegConvertor

__all__ = [
    'BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor',
    'ABIConvertor'
]
@@ -1,68 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

import mmocr.utils as utils
from mmocr.registry import MODELS
from .attn import AttnConvertor


@MODELS.register_module()
class ABIConvertor(AttnConvertor):
    """Convert between text, index and tensor for encoder-decoder based
    pipeline. Modified from AttnConvertor to get closer to ABINet's original
    implementation.

    Args:
        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
        dict_file (None|str): Character dict file path. If not none,
            higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, higher
            priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        max_seq_len (int): Maximum sequence length of label.
        lower (bool): If True, convert original string to lower case.
        start_end_same (bool): Whether to use the same index for
            start and end token or not. Default: True.
    """

    def str2tensor(self, strings):
        """
        Convert text-string into tensor. Different from
        :obj:`mmocr.models.textrecog.convertors.AttnConvertor`, the targets
        field returns target index no longer than max_seq_len (EOS token
        included).

        Args:
            strings (list[str]): For instance, ['hello', 'world']

        Returns:
            dict: A dict with two tensors.

            - | targets (list[Tensor]): [torch.Tensor([1,2,3,3,4,8]),
                torch.Tensor([5,4,6,3,7,8])]
            - | padded_targets (Tensor): Tensor of shape
                (bsz * max_seq_len).
        """
        assert utils.is_type_list(strings, str)

        tensors, padded_targets = [], []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.LongTensor(index[:self.max_seq_len - 1] +
                                      [self.end_idx])
            tensors.append(tensor)
            # target tensor for loss
            src_target = torch.LongTensor(tensor.size(0) + 1).fill_(0)
            src_target[0] = self.start_idx
            src_target[1:] = tensor
            padded_target = (torch.ones(self.max_seq_len) *
                             self.padding_idx).long()
            char_num = src_target.size(0)
            if char_num > self.max_seq_len:
                padded_target = src_target[:self.max_seq_len]
            else:
                padded_target[:char_num] = src_target
            padded_targets.append(padded_target)
        padded_targets = torch.stack(padded_targets, 0).long()

        return {'targets': tensors, 'padded_targets': padded_targets}
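Note: below is a minimal, self-contained sketch (plain PyTorch, toy index values, no mmocr imports) of the target construction the removed ABIConvertor.str2tensor documents: each label is truncated to max_seq_len - 1 indices, the EOS index is appended, and the loss target is prefixed with BOS and padded or truncated to max_seq_len. The special indices are hypothetical placeholders, not values taken from this diff.

import torch

# Hypothetical bookkeeping values, mirroring what the convertor sets up.
max_seq_len, start_idx, end_idx, padding_idx = 6, 90, 91, 92

def abi_style_target(index):
    # Keep at most max_seq_len - 1 character indices, then append EOS.
    tensor = torch.LongTensor(index[:max_seq_len - 1] + [end_idx])
    # Loss target: BOS + indices, padded (or truncated) to max_seq_len.
    src_target = torch.cat([torch.LongTensor([start_idx]), tensor])
    padded_target = torch.full((max_seq_len, ), padding_idx, dtype=torch.long)
    char_num = min(src_target.size(0), max_seq_len)
    padded_target[:char_num] = src_target[:char_num]
    return tensor, padded_target

target, padded = abi_style_target([1, 2, 3, 3, 4])
print(target)  # tensor([ 1,  2,  3,  3,  4, 91])
print(padded)  # tensor([90,  1,  2,  3,  3,  4])  (BOS kept, tail truncated)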
@@ -1,142 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

import mmocr.utils as utils
from mmocr.registry import MODELS
from .base import BaseConvertor


@MODELS.register_module()
class AttnConvertor(BaseConvertor):
    """Convert between text, index and tensor for encoder-decoder based
    pipeline.

    Args:
        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
        dict_file (None|str): Character dict file path. If not none,
            higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, higher
            priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        max_seq_len (int): Maximum sequence length of label.
        lower (bool): If True, convert original string to lower case.
        start_end_same (bool): Whether to use the same index for
            start and end token or not. Default: True.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 max_seq_len=40,
                 lower=False,
                 start_end_same=True,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(max_seq_len, int)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.max_seq_len = max_seq_len
        self.lower = lower
        self.start_end_same = start_end_same

        self.update_dict()

    def update_dict(self):
        start_end_token = '<BOS/EOS>'
        unknown_token = '<UKN>'
        padding_token = '<PAD>'

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append(unknown_token)
            self.unknown_idx = len(self.idx2char) - 1

        # BOS/EOS
        self.idx2char.append(start_end_token)
        self.start_idx = len(self.idx2char) - 1
        if not self.start_end_same:
            self.idx2char.append(start_end_token)
        self.end_idx = len(self.idx2char) - 1

        # padding
        self.idx2char.append(padding_token)
        self.padding_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def str2tensor(self, strings):
        """
        Convert text-string into tensor.
        Args:
            strings (list[str]): ['hello', 'world']
        Returns:
            dict (str: Tensor | list[tensor]):
                tensors (list[Tensor]): [torch.Tensor([1,2,3,3,4]),
                    torch.Tensor([5,4,6,3,7])]
                padded_targets (Tensor(bsz * max_seq_len))
        """
        assert utils.is_type_list(strings, str)

        tensors, padded_targets = [], []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.LongTensor(index)
            tensors.append(tensor)
            # target tensor for loss
            src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0)
            src_target[-1] = self.end_idx
            src_target[0] = self.start_idx
            src_target[1:-1] = tensor
            padded_target = (torch.ones(self.max_seq_len) *
                             self.padding_idx).long()
            char_num = src_target.size(0)
            if char_num > self.max_seq_len:
                padded_target = src_target[:self.max_seq_len]
            else:
                padded_target[:char_num] = src_target
            padded_targets.append(padded_target)
        padded_targets = torch.stack(padded_targets, 0).long()

        return {'targets': tensors, 'padded_targets': padded_targets}

    def tensor2idx(self, outputs, img_metas=None):
        """
        Convert output tensor to text-index
        Args:
            outputs (tensor): model outputs with size: N * T * C
            img_metas (list[dict]): Each dict contains one image info.
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]]
        """
        batch_size = outputs.size(0)
        ignore_indexes = [self.padding_idx]
        indexes, scores = [], []
        for idx in range(batch_size):
            seq = outputs[idx, :, :]
            seq = seq.softmax(dim=-1)
            max_value, max_idx = torch.max(seq, -1)
            str_index, str_score = [], []
            output_index = max_idx.cpu().detach().numpy().tolist()
            output_score = max_value.cpu().detach().numpy().tolist()
            for char_index, char_score in zip(output_index, output_score):
                if char_index in ignore_indexes:
                    continue
                if char_index == self.end_idx:
                    break
                str_index.append(char_index)
                str_score.append(char_score)

            indexes.append(str_index)
            scores.append(str_score)

        return indexes, scores
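Note: a small self-contained sketch (plain PyTorch, random logits, hypothetical special-token indices) of the greedy rule the removed AttnConvertor.tensor2idx documents: softmax over classes, argmax per time step, skip the padding index and stop at EOS.

import torch

end_idx, padding_idx = 91, 92  # hypothetical special-token indices

def greedy_decode(outputs):
    """outputs: (N, T, C) raw logits -> per-sample index and score lists."""
    indexes, scores = [], []
    probs = outputs.softmax(dim=-1)
    max_value, max_idx = probs.max(dim=-1)  # both (N, T)
    for val, idx in zip(max_value.tolist(), max_idx.tolist()):
        str_index, str_score = [], []
        for char_index, char_score in zip(idx, val):
            if char_index == padding_idx:
                continue  # ignore padding positions
            if char_index == end_idx:
                break  # stop at EOS
            str_index.append(char_index)
            str_score.append(char_score)
        indexes.append(str_index)
        scores.append(str_score)
    return indexes, scores

logits = torch.randn(2, 8, 93)  # batch of 2 samples, 8 steps, 93 classes
print(greedy_decode(logits))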
@@ -1,128 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.registry import MODELS
from mmocr.utils import list_from_file


@MODELS.register_module()
class BaseConvertor:
    """Convert between text, index and tensor for the text recognition
    pipeline.

    Args:
        dict_type (str): Type of dict, options are 'DICT36', 'DICT37', 'DICT90'
            and 'DICT91'.
        dict_file (None|str): Character dict file path. If not none,
            the dict_file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
    """
    start_idx = end_idx = padding_idx = 0
    unknown_idx = None
    lower = False

    dicts = dict(
        DICT36=tuple('0123456789abcdefghijklmnopqrstuvwxyz'),
        DICT90=tuple('0123456789abcdefghijklmnopqrstuvwxyz'
                     'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
                     '*+,-./:;<=>?@[\\]_`~'),
        # With space character
        DICT37=tuple('0123456789abcdefghijklmnopqrstuvwxyz '),
        DICT91=tuple('0123456789abcdefghijklmnopqrstuvwxyz'
                     'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
                     '*+,-./:;<=>?@[\\]_`~ '))

    def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
        assert dict_file is None or isinstance(dict_file, str)
        assert dict_list is None or isinstance(dict_list, list)
        self.idx2char = []
        if dict_file is not None:
            for line_num, line in enumerate(list_from_file(dict_file)):
                line = line.strip('\r\n')
                if len(line) > 1:
                    raise ValueError('Expect each line has 0 or 1 character, '
                                     f'got {len(line)} characters '
                                     f'at line {line_num + 1}')
                if line != '':
                    self.idx2char.append(line)
        elif dict_list is not None:
            self.idx2char = list(dict_list)
        else:
            if dict_type in self.dicts:
                self.idx2char = list(self.dicts[dict_type])
            else:
                raise NotImplementedError(f'Dict type {dict_type} is not '
                                          'supported')

        assert len(set(self.idx2char)) == len(self.idx2char), \
            'Invalid dictionary: Has duplicated characters.'

        self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)}

    def num_classes(self):
        """Number of output classes."""
        return len(self.idx2char)

    def str2idx(self, strings):
        """Convert strings to indexes.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
        """
        assert isinstance(strings, list)

        indexes = []
        for string in strings:
            if self.lower:
                string = string.lower()
            index = []
            for char in string:
                char_idx = self.char2idx.get(char, self.unknown_idx)
                if char_idx is None:
                    raise Exception(f'Character: {char} not in dict,'
                                    f' please check gt_label and use'
                                    f' custom dict file,'
                                    f' or set "with_unknown=True"')
                index.append(char_idx)
            indexes.append(index)

        return indexes

    def str2tensor(self, strings):
        """Convert text-string to input tensor.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]),
                torch.Tensor([5,4,6,3,7])].
        """
        raise NotImplementedError

    def idx2str(self, indexes):
        """Convert indexes to text strings.

        Args:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
        Returns:
            strings (list[str]): ['hello', 'world'].
        """
        assert isinstance(indexes, list)

        strings = []
        for index in indexes:
            string = [self.idx2char[i] for i in index]
            strings.append(''.join(string))

        return strings

    def tensor2idx(self, output):
        """Convert model output tensor to character indexes and scores.
        Args:
            output (tensor): The model outputs with size: N * T * C
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]].
        """
        raise NotImplementedError
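Note: a compact standalone sketch (pure Python, a DICT36-style toy dictionary) of the two mappings the removed BaseConvertor documents: str2idx looks each character up in char2idx, falling back to an unknown index when one is configured, and idx2str joins idx2char entries back into strings.

idx2char = list('0123456789abcdefghijklmnopqrstuvwxyz')  # DICT36-style list
char2idx = {c: i for i, c in enumerate(idx2char)}

def str2idx(strings, unknown_idx=None):
    indexes = []
    for s in strings:
        index = []
        for ch in s.lower():
            idx = char2idx.get(ch, unknown_idx)
            if idx is None:
                raise KeyError(f'Character {ch!r} not in dict')
            index.append(idx)
        indexes.append(index)
    return indexes

def idx2str(indexes):
    return [''.join(idx2char[i] for i in index) for index in indexes]

codes = str2idx(['hello', 'world'])
print(codes)           # [[17, 14, 21, 21, 24], [32, 24, 27, 21, 13]]
print(idx2str(codes))  # ['hello', 'world']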
@@ -1,145 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.registry import MODELS
from .base import BaseConvertor


@MODELS.register_module()
class CTCConvertor(BaseConvertor):
    """Convert between text, index and tensor for CTC loss-based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the file
            is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        # CTC-blank
        blank_token = '<BLK>'
        self.blank_idx = 0
        self.idx2char.insert(0, blank_token)

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def str2tensor(self, strings):
        """Convert text-string to ctc-loss input tensor.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            dict (str: tensor | list[tensor]):
                tensors (list[tensor]): [torch.Tensor([1,2,3,3,4]),
                    torch.Tensor([5,4,6,3,7])].
                flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
                target_lengths (tensor): torch.IntTensor([5,5]).
        """
        assert utils.is_type_list(strings, str)

        tensors = []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.IntTensor(index)
            tensors.append(tensor)
        target_lengths = torch.IntTensor([len(t) for t in tensors])
        flatten_target = torch.cat(tensors)

        return {
            'targets': tensors,
            'flatten_targets': flatten_target,
            'target_lengths': target_lengths
        }

    def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
        """Convert model output tensor to index-list.
        Args:
            output (tensor): The model outputs with size: N * T * C.
            img_metas (list[dict]): Each dict contains one image info.
            topk (int): The highest k classes to be returned.
            return_topk (bool): Whether to return topk or just top1.
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]]
            (
                indexes_topk (list[list[list[int]->len=topk]]):
                scores_topk (list[list[list[float]->len=topk]])
            ).
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)
        assert isinstance(topk, int)
        assert topk >= 1

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ]

        batch_size = output.size(0)
        output = F.softmax(output, dim=2)
        output = output.cpu().detach()
        batch_topk_value, batch_topk_idx = output.topk(topk, dim=2)
        batch_max_idx = batch_topk_idx[:, :, 0]
        scores_topk, indexes_topk = [], []
        scores, indexes = [], []
        feat_len = output.size(1)
        for b in range(batch_size):
            valid_ratio = valid_ratios[b]
            decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
            pred = batch_max_idx[b, :]
            select_idx = []
            prev_idx = self.blank_idx
            for t in range(decode_len):
                tmp_value = pred[t].item()
                if tmp_value not in (prev_idx, self.blank_idx):
                    select_idx.append(t)
                prev_idx = tmp_value
            select_idx = torch.LongTensor(select_idx)
            topk_value = torch.index_select(batch_topk_value[b, :, :], 0,
                                            select_idx)  # valid_seqlen * topk
            topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0,
                                          select_idx)
            topk_idx_list, topk_value_list = topk_idx.numpy().tolist(
            ), topk_value.numpy().tolist()
            indexes_topk.append(topk_idx_list)
            scores_topk.append(topk_value_list)
            indexes.append([x[0] for x in topk_idx_list])
            scores.append([x[0] for x in topk_value_list])

        if return_topk:
            return indexes_topk, scores_topk

        return indexes, scores
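Note: a self-contained sketch (plain PyTorch, blank index 0 as in the removed convertor) of the CTC greedy decoding rule tensor2idx implements: take the per-frame argmax, drop blanks and consecutive repeats, and keep the scores of the surviving frames.

import torch

blank_idx = 0  # the CTC blank occupies index 0, as in the removed convertor

def ctc_greedy_collapse(frame_probs):
    """frame_probs: (T, C) softmax scores for one sample."""
    max_value, max_idx = frame_probs.max(dim=-1)
    indexes, scores = [], []
    prev = blank_idx
    for t, cur in enumerate(max_idx.tolist()):
        # Keep a frame only if it is neither blank nor a repeat of the
        # previous frame's prediction.
        if cur != blank_idx and cur != prev:
            indexes.append(cur)
            scores.append(max_value[t].item())
        prev = cur
    return indexes, scores

probs = torch.tensor([[0.1, 0.9, 0.0],   # frame 0 -> class 1
                      [0.1, 0.9, 0.0],   # frame 1 -> repeat, dropped
                      [0.8, 0.1, 0.1],   # frame 2 -> blank, dropped
                      [0.1, 0.2, 0.7]])  # frame 3 -> class 2
print(ctc_greedy_collapse(probs))  # ([1, 2], [~0.9, ~0.7])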
@@ -1,127 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch

import mmocr.utils as utils
from mmocr.registry import MODELS
from .base import BaseConvertor


@MODELS.register_module()
class SegConvertor(BaseConvertor):
    """Convert between text, index and tensor for segmentation based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the
            file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT36',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        # background
        self.idx2char.insert(0, '<BG>')

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def tensor2str(self, output, img_metas=None):
        """Convert model output tensor to string labels.
        Args:
            output (tensor): Model outputs with size: N * C * H * W
            img_metas (list[dict]): Each dict contains one image info.
        Returns:
            texts (list[str]): Decoded text labels.
            scores (list[list[float]]): Decoded chars scores.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)

        texts, scores = [], []
        for b in range(output.size(0)):
            seg_pred = output[b].detach()
            valid_width = int(
                output.size(-1) * img_metas[b]['valid_ratio'] + 1)
            seg_res = torch.argmax(
                seg_pred[:, :, :valid_width],
                dim=0).cpu().numpy().astype(np.int32)

            seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
            _, labels, stats, centroids = cv2.connectedComponentsWithStats(
                seg_thr)

            component_num = stats.shape[0]

            all_res = []
            for i in range(component_num):
                temp_loc = (labels == i)
                temp_value = seg_res[temp_loc]
                temp_center = centroids[i]

                temp_max_num = 0
                temp_max_cls = -1
                temp_total_num = 0
                for c in range(len(self.idx2char)):
                    c_num = np.sum(temp_value == c)
                    temp_total_num += c_num
                    if c_num > temp_max_num:
                        temp_max_num = c_num
                        temp_max_cls = c

                if temp_max_cls == 0:
                    continue
                temp_max_score = 1.0 * temp_max_num / temp_total_num
                all_res.append(
                    [temp_max_cls, temp_center, temp_max_num, temp_max_score])

            all_res = sorted(all_res, key=lambda s: s[1][0])
            chars, char_scores = [], []
            for res in all_res:
                temp_area = res[2]
                if temp_area < 20:
                    continue
                temp_char_index = res[0]
                if temp_char_index >= len(self.idx2char):
                    temp_char = ''
                elif temp_char_index <= 0:
                    temp_char = ''
                elif temp_char_index == self.unknown_idx:
                    temp_char = ''
                else:
                    temp_char = self.idx2char[temp_char_index]
                chars.append(temp_char)
                char_scores.append(res[3])

            text = ''.join(chars)

            texts.append(text)
            scores.append(char_scores)

        return texts, scores
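Note: a small sketch (NumPy + OpenCV, a toy 4x8 class map) of the grouping step the removed SegConvertor.tensor2str documents: threshold the argmax class map against background, run connected components, vote for the majority class inside each component, and read the components left to right. The toy map and class ids below are made up for illustration.

import cv2
import numpy as np

# Toy per-pixel class map: 0 = background, two character blobs (classes 3, 7).
seg_res = np.array([[0, 3, 3, 0, 0, 7, 7, 0],
                    [0, 3, 3, 0, 0, 7, 7, 0],
                    [0, 0, 0, 0, 0, 7, 7, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.int32)

seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
_, labels, stats, centroids = cv2.connectedComponentsWithStats(seg_thr)

chars = []
for i in range(1, stats.shape[0]):  # component 0 is the background
    votes = seg_res[labels == i]
    majority_cls = int(np.bincount(votes).argmax())
    chars.append((centroids[i][0], majority_cls))

# Sort components left to right by centroid x, as the convertor does.
print([cls for _, cls in sorted(chars)])  # [3, 7]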
@@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .seg_head import SegHead

__all__ = ['SegHead']
@@ -1,64 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn

from mmocr.registry import MODELS


@MODELS.register_module()
class SegHead(BaseModule):
    """Head for segmentation based text recognition.

    Args:
        in_channels (int): Number of input channels :math:`C`.
        num_classes (int): Number of output classes :math:`C_{out}`.
        upsample_param (dict | None): Config dict for interpolation layer.
            Default: ``dict(scale_factor=1.0, mode='nearest')``
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 in_channels=128,
                 num_classes=37,
                 upsample_param=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(num_classes, int)
        assert num_classes > 0
        assert upsample_param is None or isinstance(upsample_param, dict)

        self.upsample_param = upsample_param

        self.seg_conv = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN'))

        # prediction
        self.pred_conv = nn.Conv2d(
            in_channels, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, out_neck):
        """
        Args:
            out_neck (list[Tensor]): A list of tensor of shape
                :math:`(N, C_i, H_i, W_i)`. The network only uses the last one
                (``out_neck[-1]``).

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, kH, kW)` where
            :math:`k` is determined by ``upsample_param``.
        """

        seg_map = self.seg_conv(out_neck[-1])
        seg_map = self.pred_conv(seg_map)

        if self.upsample_param is not None:
            seg_map = F.interpolate(seg_map, **self.upsample_param)

        return seg_map
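Note: a minimal shape-check sketch (plain PyTorch layers standing in for mmcv's ConvModule, made-up feature sizes) of the computation the removed SegHead documents: a 3x3 conv keeps the channel count, a 1x1 conv maps to num_classes, and an optional F.interpolate call rescales the prediction map.

import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels, num_classes = 128, 37            # defaults from the removed head
seg_conv = nn.Sequential(                     # stand-in for mmcv ConvModule
    nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1),
    nn.BatchNorm2d(in_channels), nn.ReLU())
pred_conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

out_neck = [torch.randn(2, in_channels, 16, 64)]  # last neck feature map
seg_map = pred_conv(seg_conv(out_neck[-1]))
seg_map = F.interpolate(seg_map, scale_factor=2.0, mode='nearest')
print(seg_map.shape)  # torch.Size([2, 37, 32, 128])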
@@ -8,9 +8,8 @@ from .nrtr import NRTR
 from .robust_scanner import RobustScanner
 from .sar import SARNet
 from .satrn import SATRN
-from .seg_recognizer import SegRecognizer
 
 __all__ = [
     'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
-    'SegRecognizer', 'RobustScanner', 'SATRN', 'ABINet', 'MASTER'
+    'RobustScanner', 'SATRN', 'ABINet', 'MASTER'
 ]
@@ -1,148 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmocr.registry import MODELS
from .base import BaseRecognizer


@MODELS.register_module()
class SegRecognizer(BaseRecognizer):
    """Base class for segmentation based recognizer."""

    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 neck=None,
                 head=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Label_convertor
        assert label_convertor is not None
        self.label_convertor = MODELS.build(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = MODELS.build(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = MODELS.build(backbone)

        # Neck
        assert neck is not None
        self.neck = MODELS.build(neck)

        # Head
        assert head is not None
        head.update(num_classes=self.label_convertor.num_classes())
        self.head = MODELS.build(head)

        # Loss
        assert loss is not None
        self.loss = MODELS.build(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if pretrained is not None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated \
                key, please consider using init_cfg')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    def extract_feat(self, img):
        """Directly extract features from the backbone."""
        if self.preprocessor is not None:
            img = self.preprocessor(img)

        x = self.backbone(img)

        return x

    def forward_train(self, img, img_metas, gt_kernels=None):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'filename', and may also contain
                'ori_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """

        feats = self.extract_feat(img)

        out_neck = self.neck(feats)

        out_head = self.head(out_neck)

        loss_inputs = (out_neck, out_head, gt_kernels)

        losses = self.loss(*loss_inputs)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Image input tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            list[str]: Text label result of each image.
        """

        feat = self.extract_feat(img)

        out_neck = self.neck(feat)

        out_head = self.head(out_neck)

        for img_meta in img_metas:
            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
            img_meta['valid_ratio'] = valid_ratio

        texts, scores = self.label_convertor.tensor2str(out_head, img_metas)

        # flatten batch results
        results = []
        for text, score in zip(texts, scores):
            results.append(dict(text=text, score=score))

        return results

    def merge_aug_results(self, aug_results):
        out_text, out_score = '', -1
        for result in aug_results:
            text = result[0]['text']
            score = sum(result[0]['score']) / max(1, len(text))
            if score > out_score:
                out_text = text
                out_score = score
        out_results = [dict(text=out_text, score=out_score)]
        return out_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation.

        Args:
            imgs (list[tensor]): Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): The metadata of images.
        """
        aug_results = []
        for img, img_meta in zip(imgs, img_metas):
            result = self.simple_test(img, img_meta, **kwargs)
            aug_results.append(result)

        return self.merge_aug_results(aug_results)
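Note: a hedged sketch of how the removed pieces were wired together in an old-style config. SegRecognizer, SegHead and SegConvertor are the types deleted in this commit; the backbone, neck and loss entries are illustrative placeholders rather than names taken from this diff, and num_classes for the head is filled in automatically from label_convertor.num_classes().

label_convertor = dict(
    type='SegConvertor', dict_type='DICT36', with_unknown=True)

model = dict(
    type='SegRecognizer',
    backbone=dict(type='...'),  # any registered recognition backbone (placeholder)
    neck=dict(type='...'),      # feature-fusion neck feeding the head (placeholder)
    head=dict(
        type='SegHead',
        in_channels=128,
        upsample_param=dict(scale_factor=2.0, mode='nearest')),
    loss=dict(type='...'),      # segmentation loss for the head (placeholder)
    label_convertor=label_convertor)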