From 59cc4efdc5552123d707a7ba84e0dee48c373aa5 Mon Sep 17 00:00:00 2001
From: tink2123
Date: Thu, 22 Jul 2021 19:58:14 +0800
Subject: [PATCH 1/3] add support for SEED

---
 configs/rec/rec_resnet_stn_bilstm_att.yml | 101 +++
 ppocr/data/imaug/label_ops.py | 64 +-
 ppocr/data/simple_dataset.py | 1 +
 ppocr/losses/__init__.py | 5 +-
 ppocr/losses/rec_aster_loss.py | 79 ++
 ppocr/losses/rec_att_loss.py | 2 +
 ppocr/modeling/backbones/__init__.py | 4 +-
 ppocr/modeling/backbones/levit.py | 707 ++++++++++++++++++
 ppocr/modeling/backbones/rec_resnet_aster.py | 147 ++++
 ppocr/modeling/heads/__init__.py | 6 +-
 ppocr/modeling/heads/rec_aster_head.py | 258 +++++++
 ppocr/modeling/heads/rec_att_head.py | 5 +
 ppocr/modeling/transforms/__init__.py | 3 +-
 ppocr/modeling/transforms/stn.py | 121 +++
 ppocr/modeling/transforms/tps.py | 29 +-
 .../transforms/tps_spatial_transformer.py | 178 +++++
 ppocr/modeling/transforms/tps_torch.py | 149 ++++
 ppocr/postprocess/rec_postprocess.py | 29 +-
 ppocr/utils/save_load.py | 17 +-
 tools/program.py | 15 +-
 tools/train.py | 2 +
 21 files changed, 1868 insertions(+), 54 deletions(-)
 create mode 100644 configs/rec/rec_resnet_stn_bilstm_att.yml
 create mode 100644 ppocr/losses/rec_aster_loss.py
 create mode 100644 ppocr/modeling/backbones/levit.py
 create mode 100644 ppocr/modeling/backbones/rec_resnet_aster.py
 create mode 100644 ppocr/modeling/heads/rec_aster_head.py
 create mode 100644 ppocr/modeling/transforms/stn.py
 create mode 100644 ppocr/modeling/transforms/tps_spatial_transformer.py
 create mode 100644 ppocr/modeling/transforms/tps_torch.py

diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
new file mode 100644
index 000000000..f705f1e23
--- /dev/null
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -0,0 +1,101 @@
+Global:
+  use_gpu: False
+  epoch_num: 400
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec/b3_rare_r34_none_gru/
+  save_epoch_step: 3
+  # evaluation is run every 2000 iterations
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  # for data or label process
+  character_dict_path:
+  character_type: EN_symbol
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    learning_rate: 0.0005
+  regularizer:
+    name: 'L2'
+    factor: 0.00000
+
+Architecture:
+  model_type: rec
+  algorithm: ASTER
+  Transform:
+    name: STN_ON
+    tps_inputsize: [32, 64]
+    tps_outputsize: [32, 100]
+    num_control_points: 20
+    tps_margins: [0.05,0.05]
+    stn_activation: none
+  Backbone:
+    name: ResNet_ASTER
+  Head:
+    name: AsterHead # AttentionHead
+    sDim: 512
+    attDim: 512
+    max_len_labels: 100
+
+Loss:
+  name: AsterLoss
+
+PostProcess:
+  name: AttnLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data/
+    label_file_list: ["./train_data/ic15_data/1.txt"]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - AttnLabelEncode: # Class handling label
+      - RecResizeImg:
+          image_shape: [3, 32, 100]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 2
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: 
./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/1.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 2 + num_workers: 8 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index e25cce79b..0e1d4939d 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -104,6 +104,7 @@ class BaseRecLabelEncode(object): self.max_text_len = max_text_length self.beg_str = "sos" self.end_str = "eos" + self.unknown = "UNKNOWN" if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -275,7 +276,9 @@ class AttnLabelEncode(BaseRecLabelEncode): def add_special_char(self, dict_character): self.beg_str = "sos" self.end_str = "eos" - dict_character = [self.beg_str] + dict_character + [self.end_str] + self.unknown = "UNKNOWN" + dict_character = [self.beg_str] + dict_character + [self.end_str + ] + [self.unknown] return dict_character def __call__(self, data): @@ -288,6 +291,7 @@ class AttnLabelEncode(BaseRecLabelEncode): data['length'] = np.array(len(text)) text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len - len(text) - 2) + data['label'] = np.array(text) return data @@ -352,19 +356,22 @@ class SRNLabelEncode(BaseRecLabelEncode): % beg_or_end return idx + class TableLabelEncode(object): """ Convert between text-label and text-index """ - def __init__(self, - max_text_length, - max_elem_length, - max_cell_num, - character_dict_path, - span_weight = 1.0, - **kwargs): + + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + span_weight=1.0, + **kwargs): self.max_text_length = max_text_length self.max_elem_length = max_elem_length self.max_cell_num = max_cell_num - list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character, list_elem = self.load_char_elem_dict( + character_dict_path) list_character = self.add_special_char(list_character) list_elem = self.add_special_char(list_elem) self.dict_character = {} @@ -374,7 +381,7 @@ class TableLabelEncode(object): for i, elem in enumerate(list_elem): self.dict_elem[elem] = i self.span_weight = span_weight - + def load_char_elem_dict(self, character_dict_path): list_character = [] list_elem = [] @@ -383,27 +390,27 @@ class TableLabelEncode(object): substr = lines[0].decode('utf-8').strip("\n").split("\t") character_num = int(substr[0]) elem_num = int(substr[1]) - for cno in range(1, 1+character_num): + for cno in range(1, 1 + character_num): character = lines[cno].decode('utf-8').strip("\n") list_character.append(character) - for eno in range(1+character_num, 1+character_num+elem_num): + for eno in range(1 + character_num, 1 + character_num + elem_num): elem = lines[eno].decode('utf-8').strip("\n") list_elem.append(elem) return list_character, list_elem - + def add_special_char(self, list_character): self.beg_str = "sos" self.end_str = "eos" list_character = [self.beg_str] + list_character + [self.end_str] return list_character - + def get_span_idx_list(self): span_idx_list = [] for elem in self.dict_elem: if 'span' in elem: span_idx_list.append(self.dict_elem[elem]) return span_idx_list - + def __call__(self, data): cells = data['cells'] 
structure = data['structure']['tokens']
@@ -412,18 +419,22 @@ class TableLabelEncode(object):
             return None
         elem_num = len(structure)
         structure = [0] + structure + [len(self.dict_elem) - 1]
-        structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
+        structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
+                                       )
         structure = np.array(structure)
         data['structure'] = structure
         elem_char_idx1 = self.dict_elem['<td>']
         elem_char_idx2 = self.dict_elem['<td']
         span_idx_list = self.get_span_idx_list()
-        td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
+        td_idx_list = np.logical_or(structure == elem_char_idx1,
+                                    structure == elem_char_idx2)
         td_idx_list = np.where(td_idx_list)[0]
-        structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
+        structure_mask = np.ones(
+            (self.max_elem_length + 2, 1), dtype=np.float32)
         bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
-        bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
+        bbox_list_mask = np.zeros(
+            (self.max_elem_length + 2, 1), dtype=np.float32)
         img_height, img_width, img_ch = data['image'].shape
         if len(span_idx_list) > 0:
             span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
@@ -450,9 +461,11 @@ class TableLabelEncode(object):
         char_end_idx = self.get_beg_end_flag_idx('end', 'char')
         elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
         elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
-        data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
-            elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
-            self.max_elem_length, self.max_cell_num, elem_num])
+        data['sp_tokens'] = np.array([
+            char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
+            elem_char_idx1, elem_char_idx2, self.max_text_length,
+            self.max_elem_length, self.max_cell_num, elem_num
+        ])
         return data
 
     def encode(self, text, char_or_elem):
@@ -504,9 +517,8 @@ class TableLabelEncode(object):
             idx = np.array(self.dict_elem[self.end_str])
         else:
             assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
-                % beg_or_end
+                          % beg_or_end
         else:
             assert False, "Unsupport type %s in char_or_elem" \
-                % char_or_elem
+                          % char_or_elem
         return idx
-
\ No newline at end of file
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index ce9e1b386..b519f4fde 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -22,6 +22,7 @@ from .imaug import transform, create_operators
 
 class SimpleDataSet(Dataset):
     def __init__(self, config, mode, logger, seed=None):
+        print("===== simpledataset ========")
         super(SimpleDataSet, self).__init__()
         self.logger = logger
         self.mode = mode.lower()
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 025ae7ca5..2a6737745 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -41,10 +41,13 @@ from .combined_loss import CombinedLoss
 # table loss
 from .table_att_loss import TableAttentionLoss
 
+from .rec_aster_loss import AsterLoss
+
+
 def build_loss(config):
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
+        'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'AsterLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py
new file mode 100644
index 000000000..858fadc02
--- /dev/null
+++ b/ppocr/losses/rec_aster_loss.py
@@ -0,0 +1,79 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
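+#
+# AsterLoss below is the SEED/ASTER recognition loss: a length-masked
+# negative log-likelihood over the attention decoder's per-step outputs.
+# The semantic term (cosine distance between the predicted holistic
+# embedding and a fastText word vector) is left commented out in this
+# patch, so only the masked NLL is returned. Shape sketch of the masking:
+#   rec_pred: [B, T, C] decoder logits; targets: [B, T] padded char ids
+#   mask[i, t] = 1.0 for t < label_lengths[i], else 0.0
+#   loss = -sum(log_softmax(rec_pred)[i, t, targets[i, t]] * mask[i, t]) / B
+# (divided by B when sample_normalize=True, by sum(mask) when
+# sequence_normalize=True).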
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import fasttext + + +class AsterLoss(nn.Layer): + def __init__(self, + weight=None, + size_average=True, + ignore_index=-100, + sequence_normalize=False, + sample_normalize=True, + **kwargs): + super(AsterLoss, self).__init__() + self.weight = weight + self.size_average = size_average + self.ignore_index = ignore_index + self.sequence_normalize = sequence_normalize + self.sample_normalize = sample_normalize + self.loss_func = paddle.nn.CosineSimilarity() + + def forward(self, predicts, batch): + targets = batch[1].astype("int64") + label_lengths = batch[2].astype('int64') + # sem_target = batch[3].astype('float32') + embedding_vectors = predicts['embedding_vectors'] + rec_pred = predicts['rec_pred'] + + # semantic loss + # print(embedding_vectors) + # print(embedding_vectors.shape) + # targets = fasttext[targets] + # sem_loss = 1 - self.loss_func(embedding_vectors, targets) + + # rec loss + batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[ + 1], rec_pred.shape[2] + assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + mask = paddle.zeros([batch_size, num_steps]) + for i in range(batch_size): + mask[i, :label_lengths[i]] = 1 + mask = paddle.cast(mask, "float32") + max_length = max(label_lengths) + assert max_length == rec_pred.shape[1] + targets = targets[:, :max_length] + mask = mask[:, :max_length] + rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]]) + input = nn.functional.log_softmax(rec_pred, axis=1) + targets = paddle.reshape(targets, [-1, 1]) + mask = paddle.reshape(mask, [-1, 1]) + # print("input:", input) + output = -paddle.gather(input, index=targets, axis=1) * mask + output = paddle.sum(output) + if self.sequence_normalize: + output = output / paddle.sum(mask) + if self.sample_normalize: + output = output / batch_size + loss = output + return {'loss': loss} # , 'sem_loss':sem_loss} diff --git a/ppocr/losses/rec_att_loss.py b/ppocr/losses/rec_att_loss.py index 6e2f67483..2d8d64b9d 100644 --- a/ppocr/losses/rec_att_loss.py +++ b/ppocr/losses/rec_att_loss.py @@ -35,5 +35,7 @@ class AttentionLoss(nn.Layer): inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]]) targets = paddle.reshape(targets, [-1]) + print("input:", paddle.argmax(inputs, axis=1)) + print("targets:", targets) return {'loss': paddle.sum(self.loss_func(inputs, targets))} diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index f4fe8c76b..e0bc45b47 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -26,8 +26,10 @@ def build_backbone(config, model_type): from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance + from .rec_resnet_aster import ResNet_ASTER support_dict = [ - "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN" + "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", + "ResNet_ASTER" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/levit.py b/ppocr/modeling/backbones/levit.py new file mode 100644 index 000000000..8b04e9def --- /dev/null +++ b/ppocr/modeling/backbones/levit.py @@ -0,0 +1,707 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
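+#
+# Paddle port of LeViT. b16() stems the input down 16x with four stride-2
+# Conv2d_BN blocks, then three attention stages (depths and widths taken
+# from the specification table) run at decreasing resolution, linked by
+# AttentionSubsample layers. The module-level FLOPS_COUNTER is accumulated
+# while layers are constructed and stored on the finished model as
+# model.FLOPS.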
+ +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License + +import paddle +import itertools +#import utils +import math +import warnings +import paddle.nn.functional as F +from paddle.nn.initializer import TruncatedNormal, Constant + +#from timm.models.vision_transformer import trunc_normal_ +#from timm.models.registry import register_model + +specification = { + 'LeViT_128S': { + 'C': '128_256_384', + 'D': 16, + 'N': '4_6_8', + 'X': '2_3_4', + 'drop_path': 0, + 'weights': + 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + }, + 'LeViT_128': { + 'C': '128_256_384', + 'D': 16, + 'N': '4_8_12', + 'X': '4_4_4', + 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + }, + 'LeViT_192': { + 'C': '192_288_384', + 'D': 32, + 'N': '3_5_6', + 'X': '4_4_4', + 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + }, + 'LeViT_256': { + 'C': '256_384_512', + 'D': 32, + 'N': '4_6_8', + 'X': '4_4_4', + 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + }, + 'LeViT_384': { + 'C': '384_512_768', + 'D': 32, + 'N': '6_9_12', + 'X': '4_4_4', + 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + }, +} + +__all__ = [specification.keys()] + +trunc_normal_ = TruncatedNormal(std=.02) +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +#@register_model +def LeViT_128S(class_dim=1000, distillation=True, pretrained=False, fuse=False): + return model_factory( + **specification['LeViT_128S'], + class_dim=class_dim, + distillation=distillation, + pretrained=pretrained, + fuse=fuse) + + +#@register_model +def LeViT_128(class_dim=1000, distillation=True, pretrained=False, fuse=False): + return model_factory( + **specification['LeViT_128'], + class_dim=class_dim, + distillation=distillation, + pretrained=pretrained, + fuse=fuse) + + +#@register_model +def LeViT_192(class_dim=1000, distillation=True, pretrained=False, fuse=False): + return model_factory( + **specification['LeViT_192'], + class_dim=class_dim, + distillation=distillation, + pretrained=pretrained, + fuse=fuse) + + +#@register_model +def LeViT_256(class_dim=1000, distillation=False, pretrained=False, fuse=False): + return model_factory( + **specification['LeViT_256'], + class_dim=class_dim, + distillation=distillation, + pretrained=pretrained, + fuse=fuse) + + +#@register_model +def LeViT_384(class_dim=1000, distillation=True, pretrained=False, fuse=False): + return model_factory( + **specification['LeViT_384'], + class_dim=class_dim, + distillation=distillation, + pretrained=pretrained, + fuse=fuse) + + +FLOPS_COUNTER = 0 + + +class Conv2d_BN(paddle.nn.Sequential): + def __init__(self, + a, + b, + ks=1, + stride=1, + pad=0, + dilation=1, + groups=1, + bn_weight_init=1, + resolution=-10000): + super().__init__() + self.add_sublayer( + 'c', + paddle.nn.Conv2D( + a, b, ks, stride, pad, dilation, groups, bias_attr=False)) + bn = paddle.nn.BatchNorm2D(b) + ones_(bn.weight) + zeros_(bn.bias) + self.add_sublayer('bn', bn) + + global FLOPS_COUNTER + output_points = ( + (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1)**2 + FLOPS_COUNTER += a * b * output_points * (ks**2) + + @paddle.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - 
bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = paddle.nn.Conv2D( + w.size(1), + w.size(0), + w.shape[2:], + stride=self.c.stride, + padding=self.c.padding, + dilation=self.c.dilation, + groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class Linear_BN(paddle.nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_sublayer('c', paddle.nn.Linear(a, b, bias_attr=False)) + bn = paddle.nn.BatchNorm1D(b) + ones_(bn.weight) + zeros_(bn.bias) + self.add_sublayer('bn', bn) + + global FLOPS_COUNTER + output_points = resolution**2 + FLOPS_COUNTER += a * b * output_points + + @paddle.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = paddle.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + l, bn = self._sub_layers.values() + x = l(x) + return paddle.reshape(bn(x.flatten(0, 1)), x.shape) + + +class BN_Linear(paddle.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_sublayer('bn', paddle.nn.BatchNorm1D(a)) + l = paddle.nn.Linear(a, b, bias_attr=bias) + trunc_normal_(l.weight) + if bias: + zeros_(l.bias) + self.add_sublayer('l', l) + global FLOPS_COUNTER + FLOPS_COUNTER += a * b + + @paddle.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @self.l.weight.T + else: + b = (l.weight @b[:, None]).view(-1) + self.l.bias + m = paddle.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return paddle.nn.Sequential( + Conv2d_BN( + 3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + Conv2d_BN( + n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + Conv2d_BN( + n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + Conv2d_BN( + n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(paddle.nn.Layer): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * paddle.rand( + x.size(0), 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(paddle.nn.Layer): + def __init__(self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + activation=None, + resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + self.h = self.dh + nh_kd * 2 + self.qkv = Linear_BN(dim, self.h, resolution=resolution) + self.proj = paddle.nn.Sequential( + activation(), + Linear_BN( + self.dh, dim, bn_weight_init=0, resolution=resolution)) + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + 
attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = self.create_parameter( + shape=(num_heads, len(attention_offsets)), + default_initializer=zeros_) + tensor_idxs = paddle.to_tensor(idxs, dtype='int64') + self.register_buffer('attention_bias_idxs', + paddle.reshape(tensor_idxs, [N, N])) + + global FLOPS_COUNTER + #queries * keys + FLOPS_COUNTER += num_heads * (resolution**4) * key_dim + # softmax + FLOPS_COUNTER += num_heads * (resolution**4) + #attention * v + FLOPS_COUNTER += num_heads * self.d * (resolution**4) + + @paddle.no_grad() + def train(self, mode=True): + if mode: + super().train() + else: + super().eval() + if mode and hasattr(self, 'ab'): + del self.ab + else: + gather_list = [] + attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) + for idx in self.attention_bias_idxs: + gather = paddle.gather(attention_bias_t, idx) + gather_list.append(gather) + attention_biases = paddle.transpose( + paddle.concat(gather_list), (1, 0)).reshape( + (0, self.attention_bias_idxs.shape[0], + self.attention_bias_idxs.shape[1])) + self.ab = attention_biases + #self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + self.training = True + B, N, C = x.shape + qkv = self.qkv(x) + qkv = paddle.reshape(qkv, + [B, N, self.num_heads, self.h // self.num_heads]) + q, k, v = paddle.split( + qkv, [self.key_dim, self.key_dim, self.d], axis=3) + q = paddle.transpose(q, perm=[0, 2, 1, 3]) + k = paddle.transpose(k, perm=[0, 2, 1, 3]) + v = paddle.transpose(v, perm=[0, 2, 1, 3]) + k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2]) + + if self.training: + gather_list = [] + attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) + for idx in self.attention_bias_idxs: + gather = paddle.gather(attention_bias_t, idx) + gather_list.append(gather) + attention_biases = paddle.transpose( + paddle.concat(gather_list), (1, 0)).reshape( + (0, self.attention_bias_idxs.shape[0], + self.attention_bias_idxs.shape[1])) + else: + attention_biases = self.ab + #np_ = paddle.to_tensor(self.attention_biases.numpy()[:, self.attention_bias_idxs.numpy()]) + #print(self.attention_bias_idxs.shape) + #print(attention_biases.shape) + #print(np_.shape) + #print(np_.equal(attention_biases)) + #exit() + + attn = ((q @k_transpose) * self.scale + attention_biases) + attn = F.softmax(attn) + x = paddle.transpose(attn @v, perm=[0, 2, 1, 3]) + x = paddle.reshape(x, [B, N, self.dh]) + x = self.proj(x) + return x + + +class Subsample(paddle.nn.Layer): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = paddle.reshape(x, [B, self.resolution, self.resolution, + C])[:, ::self.stride, ::self.stride] + x = paddle.reshape(x, [B, -1, C]) + return x + + +class AttentionSubsample(paddle.nn.Layer): + def __init__(self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=2, + activation=None, + stride=2, + resolution=14, + resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_**2 + self.training = True + h = self.dh + nh_kd + self.kv = Linear_BN(in_dim, h, resolution=resolution) + + self.q = 
paddle.nn.Sequential( + Subsample(stride, resolution), + Linear_BN( + in_dim, nh_kd, resolution=resolution_)) + self.proj = paddle.nn.Sequential( + activation(), Linear_BN( + self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list( + itertools.product(range(resolution_), range(resolution_))) + + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + i = 0 + j = 0 + for p1 in points_: + i += 1 + for p2 in points: + j += 1 + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = self.create_parameter( + shape=(num_heads, len(attention_offsets)), + default_initializer=zeros_) + + tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64') + self.register_buffer('attention_bias_idxs', + paddle.reshape(tensor_idxs_, [N_, N])) + + global FLOPS_COUNTER + #queries * keys + FLOPS_COUNTER += num_heads * \ + (resolution**2) * (resolution_**2) * key_dim + # softmax + FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2) + #attention * v + FLOPS_COUNTER += num_heads * \ + (resolution**2) * (resolution_**2) * self.d + + @paddle.no_grad() + def train(self, mode=True): + if mode: + super().train() + else: + super().eval() + if mode and hasattr(self, 'ab'): + del self.ab + else: + gather_list = [] + attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) + for idx in self.attention_bias_idxs: + gather = paddle.gather(attention_bias_t, idx) + gather_list.append(gather) + attention_biases = paddle.transpose( + paddle.concat(gather_list), (1, 0)).reshape( + (0, self.attention_bias_idxs.shape[0], + self.attention_bias_idxs.shape[1])) + self.ab = attention_biases + #self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + self.training = True + B, N, C = x.shape + kv = self.kv(x) + kv = paddle.reshape(kv, [B, N, self.num_heads, -1]) + k, v = paddle.split(kv, [self.key_dim, self.d], axis=3) + k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC + v = paddle.transpose(v, perm=[0, 2, 1, 3]) + q = paddle.reshape( + self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim]) + q = paddle.transpose(q, perm=[0, 2, 1, 3]) + + if self.training: + gather_list = [] + attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) + for idx in self.attention_bias_idxs: + gather = paddle.gather(attention_bias_t, idx) + gather_list.append(gather) + attention_biases = paddle.transpose( + paddle.concat(gather_list), (1, 0)).reshape( + (0, self.attention_bias_idxs.shape[0], + self.attention_bias_idxs.shape[1])) + else: + attention_biases = self.ab + + attn = (q @paddle.transpose( + k, perm=[0, 1, 3, 2])) * self.scale + attention_biases + attn = F.softmax(attn) + + x = paddle.reshape( + paddle.transpose( + (attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh]) + x = self.proj(x) + return x + + +class LeViT(paddle.nn.Layer): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + class_dim=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attention_activation=paddle.nn.Hardswish, + mlp_activation=paddle.nn.Hardswish, + 
distillation=True, + drop_path=0): + super().__init__() + global FLOPS_COUNTER + + self.class_dim = class_dim + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, + down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention( + ed, + kd, + nh, + attn_ratio=ar, + activation=attention_activation, + resolution=resolution, ), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual( + paddle.nn.Sequential( + Linear_BN( + ed, h, resolution=resolution), + mlp_activation(), + Linear_BN( + h, + ed, + bn_weight_init=0, + resolution=resolution), ), + drop_path)) + if do[0] == 'Subsample': + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], + key_dim=do[1], + num_heads=do[2], + attn_ratio=do[3], + activation=attention_activation, + stride=do[5], + resolution=resolution, + resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual( + paddle.nn.Sequential( + Linear_BN( + embed_dim[i + 1], h, resolution=resolution), + mlp_activation(), + Linear_BN( + h, + embed_dim[i + 1], + bn_weight_init=0, + resolution=resolution), ), + drop_path)) + self.blocks = paddle.nn.Sequential(*self.blocks) + + # Classifier head + self.head = BN_Linear( + embed_dim[-1], class_dim) if class_dim > 0 else paddle.nn.Identity() + if distillation: + self.head_dist = BN_Linear( + embed_dim[-1], + class_dim) if class_dim > 0 else paddle.nn.Identity() + + self.FLOPS = FLOPS_COUNTER + FLOPS_COUNTER = 0 + + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2) + x = paddle.transpose(x, perm=[0, 2, 1]) + x = self.blocks(x) + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, class_dim, distillation, + pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = paddle.nn.Hardswish + model = LeViT( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attention_activation=act, + mlp_activation=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + class_dim=class_dim, + drop_path=drop_path, + distillation=distillation) + # if pretrained: + # checkpoint = torch.hub.load_state_dict_from_url( + # weights, map_location='cpu') + # model.load_state_dict(checkpoint['model']) + if fuse: + utils.replace_batchnorm(model) + + return model + + +if __name__ == '__main__': + ''' + import torch + checkpoint = torch.load('../LeViT/pretrained256.pth') + torch_dict = checkpoint['net'] + paddle_dict = {} + fc_names = ["c.weight", "l.weight", "qkv.weight", "fc1.weight", 
"fc2.weight", "downsample.reduction.weight", "head.weight", "attn.proj.weight"] + rename_dict = {"running_mean": "_mean", "running_var": "_variance"} + range_tuple = (0, 502) + idx = 0 + for key in torch_dict: + idx += 1 + weight = torch_dict[key].cpu().numpy() + flag = [i in key for i in fc_names] + if any(flag): + if "emb" not in key: + print("weight {} need to be trans".format(key)) + weight = weight.transpose() + key = key.replace("running_mean", "_mean") + key = key.replace("running_var", "_variance") + paddle_dict[key]=weight + ''' + import numpy as np + net = globals()['LeViT_256'](fuse=False, + pretrained=False, + distillation=False) + load_layer_state_dict = paddle.load( + "./LeViT_256_official_nodistillation_paddle.pdparams") + #net.set_state_dict(paddle_dict) + net.set_state_dict(load_layer_state_dict) + net.eval() + #paddle.save(net.state_dict(), "./LeViT_256_official_paddle.pdparams") + #model = paddle.jit.to_static(net,input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')]) + #paddle.jit.save(model, "./LeViT_256_official_inference/inference") + #exit() + np.random.seed(123) + img = np.random.rand(1, 3, 224, 224).astype('float32') + img = paddle.to_tensor(img) + outputs = net(img).numpy() + print(outputs[0][:10]) + #print(outputs.shape) diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py new file mode 100644 index 000000000..5bb580357 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_aster.py @@ -0,0 +1,147 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn + +import sys +import math + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2D( + in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False) + + +def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): + # [n_position] + positions = paddle.arange(0, n_position) + # [feat_dim] + dim_range = paddle.arange(0, feat_dim) + dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim) + # [n_position, feat_dim] + angles = paddle.unsqueeze( + positions, axis=1) / paddle.unsqueeze( + dim_range, axis=0) + angles = paddle.cast(angles, "float32") + angles[:, 0::2] = paddle.sin(angles[:, 0::2]) + angles[:, 1::2] = paddle.cos(angles[:, 1::2]) + return angles + + +class AsterBlock(nn.Layer): + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(AsterBlock, self).__init__() + self.conv1 = conv1x1(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2D(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2D(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet_ASTER(nn.Layer): + """For aster or crnn""" + + def __init__(self, with_lstm=True, n_group=1, in_channels=3): + super(ResNet_ASTER, self).__init__() + self.with_lstm = with_lstm + self.n_group = n_group + + self.layer0 = nn.Sequential( + nn.Conv2D( + in_channels, + 32, + kernel_size=(3, 3), + stride=1, + padding=1, + bias_attr=False), + nn.BatchNorm2D(32), + nn.ReLU()) + + self.inplanes = 32 + self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] + self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] + self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25] + self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25] + self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25] + + if with_lstm: + self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2) + self.out_channels = 2 * 256 + else: + self.out_channels = 512 + + def _make_layer(self, planes, blocks, stride): + downsample = None + if stride != [1, 1] or self.inplanes != planes: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes)) + + layers = [] + layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) + self.inplanes = planes + for _ in range(1, blocks): + layers.append(AsterBlock(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + x0 = self.layer0(x) + x1 = self.layer1(x0) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + x5 = self.layer5(x4) + + cnn_feat = x5.squeeze(2) # [N, c, w] + cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1]) + if self.with_lstm: + rnn_feat, _ = self.rnn(cnn_feat) + return rnn_feat + else: + return cnn_feat + + +if __name__ == "__main__": + x = paddle.randn([3, 3, 32, 100]) + net = ResNet_ASTER() + encoder_feat = net(x) + print(encoder_feat.shape) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 509647941..cd923d78b 100755 --- 
a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -26,12 +26,15 @@ def build_head(config): from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead + from .rec_aster_head import AttentionRecognitionHead, AsterHead # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'TableAttentionHead'] + 'SRNHead', 'PGHead', 'TableAttentionHead', 'AttentionRecognitionHead', + 'AsterHead' + ] #table head from .table_att_head import TableAttentionHead @@ -39,5 +42,6 @@ def build_head(config): module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( support_dict)) + print(config) module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py new file mode 100644 index 000000000..055b10973 --- /dev/null +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -0,0 +1,258 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import paddle +from paddle import nn +from paddle.nn import functional as F + + +class AsterHead(nn.Layer): + def __init__(self, + in_channels, + out_channels, + sDim, + attDim, + max_len_labels, + time_step=25, + beam_width=5, + **kwargs): + super(AsterHead, self).__init__() + self.num_classes = out_channels + self.in_planes = in_channels + self.sDim = sDim + self.attDim = attDim + self.max_len_labels = max_len_labels + self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim, + attDim, max_len_labels) + self.time_step = time_step + self.embeder = Embedding(self.time_step, in_channels) + self.beam_width = beam_width + + def forward(self, x, targets=None, embed=None): + return_dict = {} + embedding_vectors = self.embeder(x) + rec_targets, rec_lengths = targets + + if self.training: + rec_pred = self.decoder([x, rec_targets, rec_lengths], + embedding_vectors) + return_dict['rec_pred'] = rec_pred + return_dict['embedding_vectors'] = embedding_vectors + else: + rec_pred, rec_pred_scores = self.decoder.beam_search( + x, self.beam_width, self.eos, embedding_vectors) + return_dict['rec_pred'] = rec_pred + return_dict['rec_pred_scores'] = rec_pred_scores + return_dict['embedding_vectors'] = embedding_vectors + + return return_dict + + +class Embedding(nn.Layer): + def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300): + super(Embedding, self).__init__() + self.in_timestep = in_timestep + self.in_planes = in_planes + self.embed_dim = embed_dim + self.mid_dim = mid_dim + self.eEmbed = nn.Linear( + in_timestep * in_planes, + self.embed_dim) # Embed encoder output to a word-embedding like + + def forward(self, x): + x = paddle.reshape(x, [paddle.shape(x)[0], -1]) + x = 
self.eEmbed(x)
+        return x
+
+
+class AttentionRecognitionHead(nn.Layer):
+    """
+  input: [b x 16 x 64 x in_planes]
+  output: probability sequence: [b x T x num_classes]
+  """
+
+    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
+        super(AttentionRecognitionHead, self).__init__()
+        self.num_classes = out_channels  # this is the output classes. So it includes the <EOS>.
+        self.in_planes = in_channels
+        self.sDim = sDim
+        self.attDim = attDim
+        self.max_len_labels = max_len_labels
+
+        self.decoder = DecoderUnit(
+            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
+
+    def forward(self, x, embed):
+        x, targets, lengths = x
+        batch_size = paddle.shape(x)[0]
+        # Decoder
+        state = self.decoder.get_initial_state(embed)
+        outputs = []
+
+        for i in range(max(lengths)):
+            if i == 0:
+                y_prev = paddle.full(
+                    shape=[batch_size], fill_value=self.num_classes)
+            else:
+                y_prev = targets[:, i - 1]
+
+            output, state = self.decoder(x, state, y_prev)
+            outputs.append(output)
+        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
+        return outputs
+
+    # inference stage.
+    def sample(self, x):
+        x, _, _ = x
+        batch_size = x.size(0)
+        # Decoder
+        state = paddle.zeros([1, batch_size, self.sDim])
+
+        predicted_ids, predicted_scores = [], []
+        for i in range(self.max_len_labels):
+            if i == 0:
+                y_prev = paddle.full(
+                    shape=[batch_size], fill_value=self.num_classes)
+            else:
+                y_prev = predicted
+
+            output, state = self.decoder(x, state, y_prev)
+            output = F.softmax(output, axis=1)
+            score, predicted = output.max(1)
+            predicted_ids.append(predicted.unsqueeze(1))
+            predicted_scores.append(score.unsqueeze(1))
+        predicted_ids = paddle.concat([predicted_ids, 1])
+        predicted_scores = paddle.concat([predicted_scores, 1])
+        # return predicted_ids.squeeze(), predicted_scores.squeeze()
+        return predicted_ids, predicted_scores
+
+
+class AttentionUnit(nn.Layer):
+    def __init__(self, sDim, xDim, attDim):
+        super(AttentionUnit, self).__init__()
+
+        self.sDim = sDim
+        self.xDim = xDim
+        self.attDim = attDim
+
+        self.sEmbed = nn.Linear(
+            sDim,
+            attDim,
+            weight_attr=paddle.nn.initializer.Normal(std=0.01),
+            bias_attr=paddle.nn.initializer.Constant(0.0))
+        self.xEmbed = nn.Linear(
+            xDim,
+            attDim,
+            weight_attr=paddle.nn.initializer.Normal(std=0.01),
+            bias_attr=paddle.nn.initializer.Constant(0.0))
+        self.wEmbed = nn.Linear(
+            attDim,
+            1,
+            weight_attr=paddle.nn.initializer.Normal(std=0.01),
+            bias_attr=paddle.nn.initializer.Constant(0.0))
+
+    def forward(self, x, sPrev):
+        batch_size, T, _ = x.shape  # [b x T x xDim]
+        x = paddle.reshape(x, [-1, self.xDim])  # [(b x T) x xDim]
+        xProj = self.xEmbed(x)  # [(b x T) x attDim]
+        xProj = paddle.reshape(xProj, [batch_size, T, -1])  # [b x T x attDim]
+
+        sPrev = sPrev.squeeze(0)
+        sProj = self.sEmbed(sPrev)  # [b x attDim]
+        sProj = paddle.unsqueeze(sProj, 1)  # [b x 1 x attDim]
+        sProj = paddle.expand(sProj,
+                              [batch_size, T, self.attDim])  # [b x T x attDim]
+
+        sumTanh = paddle.tanh(sProj + xProj)
+        sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
+
+        vProj = self.wEmbed(sumTanh)  # [(b x T) x 1]
+        vProj = paddle.reshape(vProj, [batch_size, T])
+
+        alpha = F.softmax(
+            vProj, axis=1)  # attention weights for each sample in the minibatch
+
+        return alpha
+
+
+class DecoderUnit(nn.Layer):
+    def __init__(self, sDim, xDim, yDim, attDim):
+        super(DecoderUnit, self).__init__()
+        self.sDim = sDim
+        self.xDim = xDim
+        self.yDim = yDim
+        self.attDim = attDim
+        self.emdDim = attDim
+
+        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
+        
self.tgt_embedding = nn.Embedding(
+            yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
+                std=0.01))  # the last is used for <BOS>
+        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
+        self.fc = nn.Linear(
+            sDim,
+            yDim,
+            weight_attr=nn.initializer.Normal(std=0.01),
+            bias_attr=nn.initializer.Constant(value=0))
+        self.embed_fc = nn.Linear(300, self.sDim)
+
+    def get_initial_state(self, embed, tile_times=1):
+        assert embed.shape[1] == 300
+        state = self.embed_fc(embed)  # N * sDim
+        if tile_times != 1:
+            state = state.unsqueeze(1)
+            trans_state = paddle.transpose(state, perm=[1, 0, 2])
+            state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
+            trans_state = paddle.transpose(state, perm=[1, 0, 2])
+            state = paddle.reshape(trans_state, shape=[-1, self.sDim])
+        state = state.unsqueeze(0)  # 1 * N * sDim
+        return state
+
+    def forward(self, x, sPrev, yPrev):
+        # x: feature sequence from the image decoder.
+        batch_size, T, _ = x.shape
+        alpha = self.attention_unit(x, sPrev)
+        context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
+        yPrev = paddle.cast(yPrev, dtype="int64")
+        yProj = self.tgt_embedding(yPrev)
+
+        concat_context = paddle.concat([yProj, context], 1)
+        concat_context = paddle.squeeze(concat_context, 1)
+        sPrev = paddle.squeeze(sPrev, 0)
+        output, state = self.gru(concat_context, sPrev)
+        output = paddle.squeeze(output, axis=1)
+        output = self.fc(output)
+        return output, state
+
+
+if __name__ == "__main__":
+    model = AttentionRecognitionHead(
+        num_classes=20,
+        in_channels=30,
+        sDim=512,
+        attDim=512,
+        max_len_labels=25,
+        out_channels=38)
+
+    data = paddle.ones([16, 64, 3])
+    targets = paddle.ones([16, 25])
+    length = paddle.to_tensor(20)
+    x = [data, targets, length]
+    output = model(x)
+    print(output.shape)
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index 4286d7691..79f112f72 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -44,10 +44,13 @@ class AttentionHead(nn.Layer):
         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
 
+        targets = targets[0]
+        print(targets)
         if targets is not None:
             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(
                     targets[:, i], onehot_dim=self.num_classes)
+                # print("char_onehots:", char_onehots)
                 (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                                char_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
@@ -104,6 +107,8 @@ class AttentionGRUCell(nn.Layer):
         alpha = paddle.transpose(alpha, [0, 2, 1])
         context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
         concat_context = paddle.concat([context, char_onehots], 1)
+        # print("concat_context:", concat_context.shape)
+        # print("prev_hidden:", prev_hidden.shape)
 
         cur_hidden = self.rnn(concat_context, prev_hidden)
 
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 78eaecccc..0e02a1c0c 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -17,8 +17,9 @@ __all__ = ['build_transform']
 
 def build_transform(config):
     from .tps import TPS
+    from .tps import STN_ON
 
-    support_dict = ['TPS']
+    support_dict = ['TPS', 'STN_ON']
 
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
new file mode 100644
index 000000000..0b26e27ae
--- /dev/null
+++ b/ppocr/modeling/transforms/stn.py
@@ -0,0 +1,121 @@
+# copyright (c) 2020 PaddlePaddle 
Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np + + +def conv3x3_block(in_channels, out_channels, stride=1): + n = 3 * 3 * out_channels + w = math.sqrt(2. / n) + conv_layer = nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=nn.initializer.Normal( + mean=0.0, std=w), + bias_attr=nn.initializer.Constant(0)) + block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU()) + return block + + +class STN(nn.Layer): + def __init__(self, in_channels, num_ctrlpoints, activation='none'): + super(STN, self).__init__() + self.in_channels = in_channels + self.num_ctrlpoints = num_ctrlpoints + self.activation = activation + self.stn_convnet = nn.Sequential( + conv3x3_block(in_channels, 32), #32x64 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(32, 64), #16x32 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(64, 128), # 8*16 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(128, 256), # 4*8 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(256, 256), # 2*4, + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(256, 256)) # 1*2 + self.stn_fc1 = nn.Sequential( + nn.Linear( + 2 * 256, + 512, + weight_attr=nn.initializer.Normal(0, 0.001), + bias_attr=nn.initializer.Constant(0)), + nn.BatchNorm1D(512), + nn.ReLU()) + fc2_bias = self.init_stn() + self.stn_fc2 = nn.Linear( + 512, + num_ctrlpoints * 2, + weight_attr=nn.initializer.Constant(0.0), + bias_attr=nn.initializer.Assign(fc2_bias)) + + def init_stn(self): + margin = 0.01 + sampling_num_per_side = int(self.num_ctrlpoints / 2) + ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side) + ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin + ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + ctrl_points = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) + if self.activation == 'none': + pass + elif self.activation == 'sigmoid': + ctrl_points = -np.log(1. / ctrl_points - 1.) 
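+        # stn_fc2 is built with zero weights below, so its initial output
+        # equals this bias, i.e. the canonical control-point grid (an
+        # identity-like transform); under 'sigmoid' activation the bias
+        # holds the logits of the grid instead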
+ ctrl_points = paddle.to_tensor(ctrl_points) + fc2_bias = paddle.reshape( + ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]]) + return fc2_bias + + def forward(self, x): + x = self.stn_convnet(x) + batch_size, _, h, w = x.shape + x = paddle.reshape(x, shape=(batch_size, -1)) + img_feat = self.stn_fc1(x) + x = self.stn_fc2(0.1 * img_feat) + if self.activation == 'sigmoid': + x = F.sigmoid(x) + x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2]) + return img_feat, x + + +if __name__ == "__main__": + in_planes = 3 + num_ctrlpoints = 20 + np.random.seed(100) + activation = 'none' # 'sigmoid' + stn_head = STN(in_planes, num_ctrlpoints, activation) + data = np.random.randn(10, 3, 32, 64).astype("float32") + print("data:", np.sum(data)) + input = paddle.to_tensor(data) + #input = paddle.randn([10, 3, 32, 64]) + control_points = stn_head(input) diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index dcce6246a..fc4621007 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -22,6 +22,9 @@ from paddle import nn, ParamAttr from paddle.nn import functional as F import numpy as np +from .tps_spatial_transformer import TPSSpatialTransformer +from .stn import STN + class ConvBNLayer(nn.Layer): def __init__(self, @@ -231,7 +234,8 @@ class GridGenerator(nn.Layer): """ Return inv_delta_C which is needed to calculate T """ F = self.F hat_eye = paddle.eye(F, dtype='float64') # F x F - hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye + hat_C = paddle.norm( + C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye hat_C = (hat_C**2) * paddle.log(hat_C) delta_C = paddle.concat( # F+3 x F+3 [ @@ -301,3 +305,26 @@ class TPS(nn.Layer): [-1, image.shape[2], image.shape[3], 2]) batch_I_r = F.grid_sample(x=image, grid=batch_P_prime) return batch_I_r + + +class STN_ON(nn.Layer): + def __init__(self, in_channels, tps_inputsize, tps_outputsize, + num_control_points, tps_margins, stn_activation): + super(STN_ON, self).__init__() + self.tps = TPSSpatialTransformer( + output_image_size=tuple(tps_outputsize), + num_control_points=num_control_points, + margins=tuple(tps_margins)) + self.stn_head = STN(in_channels=in_channels, + num_ctrlpoints=num_control_points, + activation=stn_activation) + self.tps_inputsize = tps_inputsize + self.out_channels = in_channels + + def forward(self, image): + stn_input = paddle.nn.functional.interpolate( + image, self.tps_inputsize, mode="bilinear", align_corners=True) + stn_img_feat, ctrl_points = self.stn_head(stn_input) + x, _ = self.tps(image, ctrl_points) + # print(x.shape) + return x diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py new file mode 100644 index 000000000..da54ffb78 --- /dev/null +++ b/ppocr/modeling/transforms/tps_spatial_transformer.py @@ -0,0 +1,178 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
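+#
+# Differentiable thin-plate-spline warp. A fixed (N+3)x(N+3) kernel built
+# from the radial basis phi(r) = r^2 * log(r) over the N target control
+# points is inverted once in __init__; forward() multiplies inverse_kernel
+# by the predicted source control points (padded with a 3x2 zero block),
+# maps every output pixel through target_coordinate_repr, and clips the
+# resulting grid to [0, 1] before grid_sample (the rescale to grid_sample's
+# [-1, 1] convention is left commented out in forward).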
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np +import itertools + + +def grid_sample(input, grid, canvas=None): + input.stop_gradient = False + output = F.grid_sample(input, grid) + if canvas is None: + return output + else: + input_mask = paddle.ones(shape=input.shape) + output_mask = F.grid_sample(input_mask, grid) + padded_output = output * output_mask + canvas * (1 - output_mask) + return padded_output + + +# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 +def compute_partial_repr(input_points, control_points): + N = input_points.shape[0] + M = control_points.shape[0] + pairwise_diff = paddle.reshape( + input_points, shape=[N, 1, 2]) - paddle.reshape( + control_points, shape=[1, M, 2]) + # original implementation, very slow + # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance + pairwise_diff_square = pairwise_diff * pairwise_diff + pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, + 1] + repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist) + # fix numerical error for 0 * log(0), substitute all nan with 0 + mask = repr_matrix != repr_matrix + repr_matrix[mask] = 0 + return repr_matrix + + +# output_ctrl_pts are specified, according to our task. +def build_output_control_points(num_control_points, margins): + margin_x, margin_y = margins + num_ctrl_pts_per_side = num_control_points // 2 + ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) + ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y + ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + # ctrl_pts_top = ctrl_pts_top[1:-1,:] + # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] + output_ctrl_pts_arr = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0) + output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr) + return output_ctrl_pts + + +class TPSSpatialTransformer(nn.Layer): + def __init__(self, + output_image_size=None, + num_control_points=None, + margins=None): + super(TPSSpatialTransformer, self).__init__() + self.output_image_size = output_image_size + self.num_control_points = num_control_points + self.margins = margins + + self.target_height, self.target_width = output_image_size + target_control_points = build_output_control_points(num_control_points, + margins) + N = num_control_points + # N = N - 4 + + # create padded kernel matrix + forward_kernel = paddle.zeros(shape=[N + 3, N + 3]) + target_control_partial_repr = compute_partial_repr( + target_control_points, target_control_points) + target_control_partial_repr = paddle.cast(target_control_partial_repr, + forward_kernel.dtype) + forward_kernel[:N, :N] = target_control_partial_repr + forward_kernel[:N, -3] = 1 + forward_kernel[-3, :N] = 1 + target_control_points = paddle.cast(target_control_points, + forward_kernel.dtype) + forward_kernel[:N, -2:] = target_control_points + forward_kernel[-2:, :N] = paddle.transpose( + target_control_points, perm=[1, 0]) + # compute inverse matrix + inverse_kernel = paddle.inverse(forward_kernel) + + # create target cordinate matrix + HW = self.target_height * self.target_width + target_coordinate = list( + itertools.product( + range(self.target_height), range(self.target_width))) + target_coordinate = 
paddle.to_tensor(target_coordinate) # HW x 2 + Y, X = paddle.split( + target_coordinate, target_coordinate.shape[1], axis=1) + #Y, X = target_coordinate.split(1, dim = 1) + Y = Y / (self.target_height - 1) + X = X / (self.target_width - 1) + target_coordinate = paddle.concat( + [X, Y], axis=1) # convert from (y, x) to (x, y) + target_coordinate_partial_repr = compute_partial_repr( + target_coordinate, target_control_points) + target_coordinate_repr = paddle.concat( + [ + target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]), + target_coordinate + ], + axis=1) + + # register precomputed matrices + self.inverse_kernel = inverse_kernel + self.padding_matrix = paddle.zeros(shape=[3, 2]) + self.target_coordinate_repr = target_coordinate_repr + self.target_control_points = target_control_points + + def forward(self, input, source_control_points): + assert source_control_points.ndimension() == 3 + assert source_control_points.shape[1] == self.num_control_points + assert source_control_points.shape[2] == 2 + batch_size = source_control_points.shape[0] + + self.padding_matrix = paddle.expand( + self.padding_matrix, shape=[batch_size, 3, 2]) + Y = paddle.concat([source_control_points, self.padding_matrix], 1) + mapping_matrix = paddle.matmul(self.inverse_kernel, Y) + source_coordinate = paddle.matmul(self.target_coordinate_repr, + mapping_matrix) + + grid = paddle.reshape( + source_coordinate, + shape=[-1, self.target_height, self.target_width, 2]) + grid = paddle.clip(grid, 0, + 1) # the source_control_points may be out of [0, 1]. + # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] + # grid = 2.0 * grid - 1.0 + output_maps = grid_sample(input, grid, canvas=None) + return output_maps, source_coordinate + + +if __name__ == "__main__": + from stn import STN + in_planes = 3 + num_ctrlpoints = 20 + np.random.seed(100) + activation = 'none' # 'sigmoid' + stn_head = STN(in_planes, num_ctrlpoints, activation) + data = np.random.randn(10, 3, 32, 64).astype("float32") + input = paddle.to_tensor(data) + #input = paddle.randn([10, 3, 32, 64]) + control_points = stn_head(input) + #print("control points:", control_points) + #input = paddle.randn(shape=[10,3,32,100]) + tps = TPSSpatialTransformer( + output_image_size=[32, 320], + num_control_points=20, + margins=[0.05, 0.05]) + out = tps(input, control_points[1]) + print("out 0 :", out[0].shape) + print("out 1:", out[1].shape) diff --git a/ppocr/modeling/transforms/tps_torch.py b/ppocr/modeling/transforms/tps_torch.py new file mode 100644 index 000000000..7aee133ae --- /dev/null +++ b/ppocr/modeling/transforms/tps_torch.py @@ -0,0 +1,149 @@ +from __future__ import absolute_import + +import numpy as np +import itertools + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def grid_sample(input, grid, canvas=None): + output = F.grid_sample(input, grid) + if canvas is None: + return output + else: + input_mask = input.data.new(input.size()).fill_(1) + output_mask = F.grid_sample(input_mask, grid) + padded_output = output * output_mask + canvas * (1 - output_mask) + return padded_output + + +# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 +def compute_partial_repr(input_points, control_points): + N = input_points.size(0) + M = control_points.size(0) + pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) + # original implementation, very slow + # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance + pairwise_diff_square = pairwise_diff * pairwise_diff + 
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, + 1] + repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) + # fix numerical error for 0 * log(0), substitute all nan with 0 + mask = repr_matrix != repr_matrix + repr_matrix.masked_fill_(mask, 0) + return repr_matrix + + +# output_ctrl_pts are specified, according to our task. +def build_output_control_points(num_control_points, margins): + margin_x, margin_y = margins + num_ctrl_pts_per_side = num_control_points // 2 + ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) + ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y + ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + # ctrl_pts_top = ctrl_pts_top[1:-1,:] + # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] + output_ctrl_pts_arr = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0) + output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) + return output_ctrl_pts + + +# demo: ~/test/models/test_tps_transformation.py +class TPSSpatialTransformer(nn.Module): + def __init__(self, + output_image_size=None, + num_control_points=None, + margins=None): + super(TPSSpatialTransformer, self).__init__() + self.output_image_size = output_image_size + self.num_control_points = num_control_points + self.margins = margins + + self.target_height, self.target_width = output_image_size + target_control_points = build_output_control_points(num_control_points, + margins) + N = num_control_points + # N = N - 4 + + # create padded kernel matrix + forward_kernel = torch.zeros(N + 3, N + 3) + target_control_partial_repr = compute_partial_repr( + target_control_points, target_control_points) + forward_kernel[:N, :N].copy_(target_control_partial_repr) + forward_kernel[:N, -3].fill_(1) + forward_kernel[-3, :N].fill_(1) + forward_kernel[:N, -2:].copy_(target_control_points) + forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) + # compute inverse matrix + inverse_kernel = torch.inverse(forward_kernel) + + # create target cordinate matrix + HW = self.target_height * self.target_width + target_coordinate = list( + itertools.product( + range(self.target_height), range(self.target_width))) + target_coordinate = torch.Tensor(target_coordinate) # HW x 2 + Y, X = target_coordinate.split(1, dim=1) + Y = Y / (self.target_height - 1) + X = X / (self.target_width - 1) + target_coordinate = torch.cat([X, Y], + dim=1) # convert from (y, x) to (x, y) + target_coordinate_partial_repr = compute_partial_repr( + target_coordinate, target_control_points) + target_coordinate_repr = torch.cat([ + target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate + ], + dim=1) + + # register precomputed matrices + self.register_buffer('inverse_kernel', inverse_kernel) + self.register_buffer('padding_matrix', torch.zeros(3, 2)) + self.register_buffer('target_coordinate_repr', target_coordinate_repr) + self.register_buffer('target_control_points', target_control_points) + + def forward(self, input, source_control_points): + assert source_control_points.ndimension() == 3 + assert source_control_points.size(1) == self.num_control_points + assert source_control_points.size(2) == 2 + batch_size = source_control_points.size(0) + + Y = torch.cat([ + source_control_points, self.padding_matrix.expand(batch_size, 3, 2) + ], 1) + mapping_matrix = torch.matmul(self.inverse_kernel, Y) + source_coordinate = 
torch.matmul(self.target_coordinate_repr,
+                                         mapping_matrix)
+
+        grid = source_coordinate.view(-1, self.target_height, self.target_width,
+                                      2)
+        grid = torch.clamp(grid, 0,
+                           1)  # the source_control_points may be out of [0, 1].
+        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+        grid = 2.0 * grid - 1.0
+        output_maps = grid_sample(input, grid, canvas=None)
+        return output_maps, source_coordinate
+
+
+if __name__ == "__main__":
+    from stn_torch import STNHead
+    in_planes = 3
+    num_ctrlpoints = 20
+    torch.manual_seed(10)
+    activation = 'none'  # 'sigmoid'
+    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
+    np.random.seed(100)
+    data = np.random.randn(10, 3, 32, 64).astype("float32")
+    input = torch.tensor(data)
+    control_points = stn_head(input)
+    tps = TPSSpatialTransformer(
+        output_image_size=[32, 320],
+        num_control_points=20,
+        margins=[0.05, 0.05])
+    out = tps(input, control_points[1])
+    print("out 0 :", out[0].shape)
+    print("out 1:", out[1].shape)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 8426bcf2b..17fc7e461 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -170,8 +170,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
+        self.unknown = "UNKNOWN"
         dict_character = dict_character
-        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        dict_character = [self.beg_str] + dict_character + [self.end_str
+                                                            ] + [self.unknown]
         return dict_character
 
     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
@@ -212,6 +214,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
             label = self.decode(label, is_remove_duplicate=False)
             return text, label
         """
+        preds = preds["rec_pred"]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
 
@@ -324,10 +327,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
 class TableLabelDecode(object):
     """  """
 
-    def __init__(self,
-                 character_dict_path,
-                 **kwargs):
-        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+    def __init__(self, character_dict_path, **kwargs):
+        list_character, list_elem = self.load_char_elem_dict(
+            character_dict_path)
         list_character = self.add_special_char(list_character)
         list_elem = self.add_special_char(list_elem)
         self.dict_character = {}
@@ -366,14 +368,14 @@ class TableLabelDecode(object):
     def __call__(self, preds):
         structure_probs = preds['structure_probs']
         loc_preds = preds['loc_preds']
-        if isinstance(structure_probs,paddle.Tensor):
+        if isinstance(structure_probs, paddle.Tensor):
             structure_probs = structure_probs.numpy()
-        if isinstance(loc_preds,paddle.Tensor):
+        if isinstance(loc_preds, paddle.Tensor):
             loc_preds = loc_preds.numpy()
         structure_idx = structure_probs.argmax(axis=2)
         structure_probs = structure_probs.max(axis=2)
-        structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
-                                                  structure_probs, 'elem')
+        structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
+            structure_idx, structure_probs, 'elem')
         res_html_code_list = []
         res_loc_list = []
         batch_num = len(structure_str)
@@ -388,8 +390,13 @@ class TableLabelDecode(object):
             res_loc = np.array(res_loc)
             res_html_code_list.append(res_html_code)
             res_loc_list.append(res_loc)
-        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
-                'res_elem_idx_list': 
result_elem_idx_list,'structure_str_list':structure_str} + return { + 'res_html_code': res_html_code_list, + 'res_loc': res_loc_list, + 'res_score_list': result_score_list, + 'res_elem_idx_list': result_elem_idx_list, + 'structure_str_list': structure_str + } def decode(self, text_index, structure_probs, char_or_elem): """convert text-label into text-index. diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 1d760e983..0453509c7 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -105,13 +105,16 @@ def load_dygraph_params(config, model, logger, optimizer): params = paddle.load(pm) state_dict = model.state_dict() new_state_dict = {} - for k1, k2 in zip(state_dict.keys(), params.keys()): - if list(state_dict[k1].shape) == list(params[k2].shape): - new_state_dict[k1] = params[k2] - else: - logger.info( - f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" - ) + # for k1, k2 in zip(state_dict.keys(), params.keys()): + for k1 in state_dict.keys(): + if k1 not in params: + continue + if list(state_dict[k1].shape) == list(params[k1].shape): + new_state_dict[k1] = params[k1] + else: + logger.info( + f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !" + ) model.set_state_dict(new_state_dict) logger.info(f"loaded pretrained_model successful from {pm}") return {} diff --git a/tools/program.py b/tools/program.py index 2d99f2968..920cf417c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -187,6 +187,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" model_type = config['Architecture']['model_type'] + algorithm = config['Architecture']['algorithm'] if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] @@ -210,10 +211,14 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table': - preds = model(images, data=batch[1:]) - else: - preds = model(images) + # if use_srn or model_type == 'table' or algorithm == "ASTER": + # preds = model(images, data=batch[1:]) + # else: + # preds = model(images) + preds = model(images, data=batch[1:]) + state_dict = model.state_dict() + # for key in state_dict: + # print(key) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -395,7 +400,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'TableAttn' + 'CLS', 'PGNet', 'Distillation', 'TableAttn', 'ASTER' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index 20f5a670d..e1515f57c 100755 --- a/tools/train.py +++ b/tools/train.py @@ -72,6 +72,8 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) + character = getattr(post_process_class, 'character') + print("getattr character:", character) if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: From c9e1077daac3efb2e5c42ebf879aa363d4c59db4 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Mon, 30 Aug 2021 06:32:54 +0000 Subject: [PATCH 2/3] polish code --- configs/rec/rec_resnet_stn_bilstm_att.yml | 65 +- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 38 +- 
 ppocr/data/imaug/operators.py                 |  16 +-
 ppocr/data/imaug/rec_img_aug.py               |  23 +
 ppocr/data/simple_dataset.py                  |   1 -
 ppocr/losses/rec_aster_loss.py                |  55 +-
 ppocr/losses/rec_att_loss.py                  |   2 -
 ppocr/metrics/rec_metric.py                   |  12 +-
 ppocr/modeling/backbones/__init__.py          |   7 +-
 ppocr/modeling/backbones/levit.py             | 707 ------------------
 ppocr/modeling/heads/__init__.py              |   1 -
 ppocr/modeling/heads/rec_aster_head.py        | 208 +++++-
 ppocr/modeling/heads/rec_att_head.py          |   5 -
 ppocr/modeling/transforms/stn.py              |  13 -
 ppocr/modeling/transforms/tps.py              |   1 +
 .../transforms/tps_spatial_transformer.py     |  27 +-
 ppocr/modeling/transforms/tps_torch.py        | 149 ----
 ppocr/optimizer/optimizer.py                  |  31 +
 ppocr/postprocess/__init__.py                 |   4 +-
 ppocr/postprocess/rec_postprocess.py          |  87 ++-
 ppocr/utils/save_load.py                      |  17 +-
 tools/program.py                              |  10 +-
 23 files changed, 461 insertions(+), 1020 deletions(-)
 delete mode 100644 ppocr/modeling/backbones/levit.py
 delete mode 100644 ppocr/modeling/transforms/tps_torch.py

diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
index f705f1e23..7b5a9c711 100644
--- a/configs/rec/rec_resnet_stn_bilstm_att.yml
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -1,9 +1,9 @@
 Global:
-  use_gpu: False
+  use_gpu: True
   epoch_num: 400
   log_smooth_window: 20
   print_batch_step: 10
-  save_model_dir: ./output/rec/b3_rare_r34_none_gru/
+  save_model_dir: ./output/rec/seed
   save_epoch_step: 3
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [0, 2000]
@@ -12,28 +12,32 @@ Global:
   checkpoints:
   save_inference_dir:
   use_visualdl: False
-  infer_img: doc/imgs_words/ch/word_1.jpg
+  infer_img: doc/imgs_words_en/word_10.png
   # for data or label process
   character_dict_path:
   character_type: EN_symbol
-  max_text_length: 25
+  max_text_length: 100
   infer_mode: False
   use_space_char: False
-  save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
+  eval_filter: True
+  save_res_path: ./output/rec/predicts_seed.txt
 
 
 Optimizer:
-  name: Adam
-  beta1: 0.9
-  beta2: 0.999
+  name: Adadelta
+  weight_decay: 0.0
+  momentum: 0.9
   lr:
-    learning_rate: 0.0005
+    name: Piecewise
+    decay_epochs: [4,5,8]
+    values: [1.0, 0.1, 0.01]
   regularizer:
     name: 'L2'
-    factor: 0.00000
+    factor: 2.0e-05
+
 
 Architecture:
-  model_type: rec
+  model_type: seed
   algorithm: ASTER
   Transform:
     name: STN_ON
@@ -54,48 +58,49 @@ Loss:
   name: AsterLoss
 
 PostProcess:
-  name: AttnLabelDecode
+  name: SEEDLabelDecode
 
 Metric:
   name: RecMetric
   main_indicator: acc
+  is_filter: True
 
 Train:
   dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/ic15_data/
-    label_file_list: ["./train_data/ic15_data/1.txt"]
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/training/
     transforms:
+      - Fasttext:
+          path: "./cc.en.300.bin"
       - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
-      - AttnLabelEncode: # Class handling label
-      - RecResizeImg:
-          image_shape: [3, 32, 100]
+      - SEEDLabelEncode: # Class handling label
+      - SEEDResize:
+          image_shape: [3, 64, 256]
       - KeepKeys:
-          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+          keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
   loader:
     shuffle: True
-    batch_size_per_card: 2
+    batch_size_per_card: 256
     drop_last: True
-    num_workers: 8
+    num_workers: 6
 
 Eval:
   dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/ic15_data/
-    label_file_list: ["./train_data/ic15_data/1.txt"]
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/evaluation/
     transforms:
      - DecodeImage: # 
load image
         img_mode: BGR
         channel_first: False
-      - AttnLabelEncode: # Class handling label
-      - RecResizeImg:
-          image_shape: [3, 32, 100]
+      - SEEDLabelEncode: # Class handling label
+      - SEEDResize:
+          image_shape: [3, 64, 256]
       - KeepKeys:
        keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
-    drop_last: False
-    batch_size_per_card: 2
-    num_workers: 8
+    drop_last: True
+    batch_size_per_card: 256
+    num_workers: 4
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 52194eb96..7a792c2fe 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, PSERandomCrop
 
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SEEDResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
 from .operators import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 0e1d4939d..21d910304 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -276,9 +276,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unknown = "UNKNOWN"
-        dict_character = [self.beg_str] + dict_character + [self.end_str
-                                                            ] + [self.unknown]
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
     def __call__(self, data):
@@ -291,7 +289,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
         data['length'] = np.array(len(text))
         text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
                                                                - len(text) - 2)
-
         data['label'] = np.array(text)
         return data
 
@@ -311,6 +308,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
         return idx
 
 
+class SEEDLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text)) + 1  # include eos
+        text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
+                                                   )
+        data['label'] = np.array(text)
+        return data
+
+
 class SRNLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
 
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 2535b4420..ba5f01b4e 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -23,6 +23,7 @@ import sys
 import six
 import cv2
 import numpy as np
+import fasttext
 
 
 class DecodeImage(object):
@@ -81,7 +82,7 @@ class NormalizeImage(object):
         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
         data['image'] = (
-            img.astype('float32') * self.scale - self.mean) / self.std
+            img.astype('float32') * self.scale - self.mean) / self.std
         return data
 
 
@@ -101,6 +102,17 @@ class ToCHWImage(object):
         return data
 
 
+class Fasttext(object):
+    
def __init__(self, path="None", **kwargs): + self.fast_model = fasttext.load_model(path) + + def __call__(self, data): + label = data['label'] + fast_label = self.fast_model[label] + data['fast_label'] = fast_label + return data + + class KeepKeys(object): def __init__(self, keep_keys, **kwargs): self.keep_keys = keep_keys @@ -183,7 +195,7 @@ class DetResizeForTest(object): else: ratio = 1. elif self.limit_type == 'resize_long': - ratio = float(limit_side_len) / max(h,w) + ratio = float(limit_side_len) / max(h, w) else: raise Exception('not support limit type, image ') resize_h = int(h * ratio) diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 28e6bd0bc..ed5b7a52c 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -63,6 +63,18 @@ class RecResizeImg(object): return data +class SEEDResize(object): + def __init__(self, image_shape, infer_mode=False, **kwargs): + self.image_shape = image_shape + self.infer_mode = infer_mode + + def __call__(self, data): + img = data['image'] + norm_img = resize_no_padding_img(img, self.image_shape) + data['image'] = norm_img + return data + + class SRNRecResizeImg(object): def __init__(self, image_shape, num_heads, max_text_length, **kwargs): self.image_shape = image_shape @@ -106,6 +118,17 @@ def resize_norm_img(img, image_shape): return padding_im +def resize_no_padding_img(img, image_shape): + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image + + def resize_norm_img_chinese(img, image_shape): imgC, imgH, imgW = image_shape # todo: change to 0 and modified image shape diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index b519f4fde..ce9e1b386 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -22,7 +22,6 @@ from .imaug import transform, create_operators class SimpleDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): - print("===== simpledataset ========") super(SimpleDataSet, self).__init__() self.logger = logger self.mode = mode.lower() diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py index 858fadc02..d900617ff 100644 --- a/ppocr/losses/rec_aster_loss.py +++ b/ppocr/losses/rec_aster_loss.py @@ -18,7 +18,26 @@ from __future__ import print_function import paddle from paddle import nn -import fasttext + + +class CosineEmbeddingLoss(nn.Layer): + def __init__(self, margin=0.): + super(CosineEmbeddingLoss, self).__init__() + self.margin = margin + + def forward(self, x1, x2, target): + similarity = paddle.fluid.layers.reduce_sum( + x1 * x2, dim=-1) / (paddle.norm( + x1, axis=-1) * paddle.norm( + x2, axis=-1)) + one_list = paddle.full_like(target, fill_value=1) + out = paddle.fluid.layers.reduce_mean( + paddle.where( + paddle.equal(target, one_list), 1. 
- similarity, + paddle.maximum( + paddle.zeros_like(similarity), similarity - self.margin))) + + return out class AsterLoss(nn.Layer): @@ -35,28 +54,28 @@ class AsterLoss(nn.Layer): self.ignore_index = ignore_index self.sequence_normalize = sequence_normalize self.sample_normalize = sample_normalize - self.loss_func = paddle.nn.CosineSimilarity() + self.loss_sem = CosineEmbeddingLoss() + self.is_cosin_loss = True + self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none') def forward(self, predicts, batch): targets = batch[1].astype("int64") label_lengths = batch[2].astype('int64') - # sem_target = batch[3].astype('float32') + sem_target = batch[3].astype('float32') embedding_vectors = predicts['embedding_vectors'] rec_pred = predicts['rec_pred'] - # semantic loss - # print(embedding_vectors) - # print(embedding_vectors.shape) - # targets = fasttext[targets] - # sem_loss = 1 - self.loss_func(embedding_vectors, targets) + if not self.is_cosin_loss: + sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target)) + else: + label_target = paddle.ones([embedding_vectors.shape[0]]) + sem_loss = paddle.sum( + self.loss_sem(embedding_vectors, sem_target, label_target)) # rec loss - batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[ - 1], rec_pred.shape[2] - assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \ - "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + batch_size, def_max_length = targets.shape[0], targets.shape[1] - mask = paddle.zeros([batch_size, num_steps]) + mask = paddle.zeros([batch_size, def_max_length]) for i in range(batch_size): mask[i, :label_lengths[i]] = 1 mask = paddle.cast(mask, "float32") @@ -64,16 +83,16 @@ class AsterLoss(nn.Layer): assert max_length == rec_pred.shape[1] targets = targets[:, :max_length] mask = mask[:, :max_length] - rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]]) + rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]]) input = nn.functional.log_softmax(rec_pred, axis=1) targets = paddle.reshape(targets, [-1, 1]) mask = paddle.reshape(mask, [-1, 1]) - # print("input:", input) - output = -paddle.gather(input, index=targets, axis=1) * mask + output = -paddle.index_sample(input, index=targets) * mask output = paddle.sum(output) if self.sequence_normalize: output = output / paddle.sum(mask) if self.sample_normalize: output = output / batch_size - loss = output - return {'loss': loss} # , 'sem_loss':sem_loss} + + loss = output + sem_loss * 0.1 + return {'loss': loss} diff --git a/ppocr/losses/rec_att_loss.py b/ppocr/losses/rec_att_loss.py index 2d8d64b9d..6e2f67483 100644 --- a/ppocr/losses/rec_att_loss.py +++ b/ppocr/losses/rec_att_loss.py @@ -35,7 +35,5 @@ class AttentionLoss(nn.Layer): inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]]) targets = paddle.reshape(targets, [-1]) - print("input:", paddle.argmax(inputs, axis=1)) - print("targets:", targets) return {'loss': paddle.sum(self.loss_func(inputs, targets))} diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 66c084d77..db2f41c3a 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -13,13 +13,20 @@ # limitations under the License. 
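
# --- Illustrative aside (not part of the diff): a toy version of the masked
# sequence loss that AsterLoss above assembles. Shapes and values are made up;
# only the masking/normalization pattern mirrors the real code.
import paddle

batch_size, max_len, num_classes = 2, 4, 5
rec_pred = paddle.rand([batch_size, max_len, num_classes])
targets = paddle.randint(0, num_classes, [batch_size, max_len])
label_lengths = [3, 2]  # valid steps per sample, eos included

mask = paddle.zeros([batch_size, max_len])
for i in range(batch_size):
    mask[i, :label_lengths[i]] = 1  # 1 for real steps, 0 for padding

log_probs = paddle.nn.functional.log_softmax(
    paddle.reshape(rec_pred, [-1, num_classes]), axis=1)
targets_flat = paddle.reshape(targets, [-1, 1])
mask_flat = paddle.reshape(mask, [-1, 1])
# negative log-likelihood of the target class; padded steps contribute 0
nll = -paddle.index_sample(log_probs, index=targets_flat) * mask_flat
loss = paddle.sum(nll) / batch_size  # the sample_normalize branch
print(float(loss))
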
import Levenshtein +import string class RecMetric(object): - def __init__(self, main_indicator='acc', **kwargs): + def __init__(self, main_indicator='acc', is_filter=False, **kwargs): self.main_indicator = main_indicator + self.is_filter = is_filter self.reset() + def _normalize_text(self, text): + text = ''.join( + filter(lambda x: x in (string.digits + string.ascii_letters), text)) + return text.lower() + def __call__(self, pred_label, *args, **kwargs): preds, labels = pred_label correct_num = 0 @@ -28,6 +35,9 @@ class RecMetric(object): for (pred, pred_conf), (target, _) in zip(preds, labels): pred = pred.replace(" ", "") target = target.replace(" ", "") + if self.is_filter: + pred = self._normalize_text(pred) + target = self._normalize_text(target) norm_edit_dis += Levenshtein.distance(pred, target) / max( len(pred), len(target), 1) if pred == target: diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index e0bc45b47..25cedb162 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -26,10 +26,8 @@ def build_backbone(config, model_type): from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance - from .rec_resnet_aster import ResNet_ASTER support_dict = [ - "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", - "ResNet_ASTER" + "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet @@ -38,6 +36,9 @@ def build_backbone(config, model_type): from .table_resnet_vd import ResNet from .table_mobilenet_v3 import MobileNetV3 support_dict = ["ResNet", "MobileNetV3"] + elif model_type == "seed": + from .rec_resnet_aster import ResNet_ASTER + support_dict = ["ResNet_ASTER"] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/levit.py b/ppocr/modeling/backbones/levit.py deleted file mode 100644 index 8b04e9def..000000000 --- a/ppocr/modeling/backbones/levit.py +++ /dev/null @@ -1,707 +0,0 @@ -# Copyright (c) 2015-present, Facebook, Inc. -# All rights reserved. 
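
# --- Illustrative aside (not part of the diff): what the is_filter option
# added to RecMetric above does before comparing prediction and label — keep
# only digits and ASCII letters and lowercase the result, i.e. the usual
# case- and punctuation-insensitive protocol for English scene-text benchmarks.
import string

def normalize_text(text):
    text = ''.join(
        filter(lambda x: x in (string.digits + string.ascii_letters), text))
    return text.lower()

print(normalize_text("Hello, World!"))  # -> "helloworld"
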
- -# Modified from -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -# Copyright 2020 Ross Wightman, Apache-2.0 License - -import paddle -import itertools -#import utils -import math -import warnings -import paddle.nn.functional as F -from paddle.nn.initializer import TruncatedNormal, Constant - -#from timm.models.vision_transformer import trunc_normal_ -#from timm.models.registry import register_model - -specification = { - 'LeViT_128S': { - 'C': '128_256_384', - 'D': 16, - 'N': '4_6_8', - 'X': '2_3_4', - 'drop_path': 0, - 'weights': - 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' - }, - 'LeViT_128': { - 'C': '128_256_384', - 'D': 16, - 'N': '4_8_12', - 'X': '4_4_4', - 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' - }, - 'LeViT_192': { - 'C': '192_288_384', - 'D': 32, - 'N': '3_5_6', - 'X': '4_4_4', - 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' - }, - 'LeViT_256': { - 'C': '256_384_512', - 'D': 32, - 'N': '4_6_8', - 'X': '4_4_4', - 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' - }, - 'LeViT_384': { - 'C': '384_512_768', - 'D': 32, - 'N': '6_9_12', - 'X': '4_4_4', - 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' - }, -} - -__all__ = [specification.keys()] - -trunc_normal_ = TruncatedNormal(std=.02) -zeros_ = Constant(value=0.) -ones_ = Constant(value=1.) - - -#@register_model -def LeViT_128S(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_128S'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_128(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_128'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_192(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_192'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_256(class_dim=1000, distillation=False, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_256'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_384(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_384'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -FLOPS_COUNTER = 0 - - -class Conv2d_BN(paddle.nn.Sequential): - def __init__(self, - a, - b, - ks=1, - stride=1, - pad=0, - dilation=1, - groups=1, - bn_weight_init=1, - resolution=-10000): - super().__init__() - self.add_sublayer( - 'c', - paddle.nn.Conv2D( - a, b, ks, stride, pad, dilation, groups, bias_attr=False)) - bn = paddle.nn.BatchNorm2D(b) - ones_(bn.weight) - zeros_(bn.bias) - self.add_sublayer('bn', bn) - - global FLOPS_COUNTER - output_points = ( - (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1)**2 - FLOPS_COUNTER += a * b * output_points * (ks**2) - - @paddle.no_grad() - def fuse(self): - c, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps)**0.5 - w = c.weight * w[:, None, None, None] - b = bn.bias - 
bn.running_mean * bn.weight / \ - (bn.running_var + bn.eps)**0.5 - m = paddle.nn.Conv2D( - w.size(1), - w.size(0), - w.shape[2:], - stride=self.c.stride, - padding=self.c.padding, - dilation=self.c.dilation, - groups=self.c.groups) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -class Linear_BN(paddle.nn.Sequential): - def __init__(self, a, b, bn_weight_init=1, resolution=-100000): - super().__init__() - self.add_sublayer('c', paddle.nn.Linear(a, b, bias_attr=False)) - bn = paddle.nn.BatchNorm1D(b) - ones_(bn.weight) - zeros_(bn.bias) - self.add_sublayer('bn', bn) - - global FLOPS_COUNTER - output_points = resolution**2 - FLOPS_COUNTER += a * b * output_points - - @paddle.no_grad() - def fuse(self): - l, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps)**0.5 - w = l.weight * w[:, None] - b = bn.bias - bn.running_mean * bn.weight / \ - (bn.running_var + bn.eps)**0.5 - m = paddle.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - def forward(self, x): - l, bn = self._sub_layers.values() - x = l(x) - return paddle.reshape(bn(x.flatten(0, 1)), x.shape) - - -class BN_Linear(paddle.nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): - super().__init__() - self.add_sublayer('bn', paddle.nn.BatchNorm1D(a)) - l = paddle.nn.Linear(a, b, bias_attr=bias) - trunc_normal_(l.weight) - if bias: - zeros_(l.bias) - self.add_sublayer('l', l) - global FLOPS_COUNTER - FLOPS_COUNTER += a * b - - @paddle.no_grad() - def fuse(self): - bn, l = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps)**0.5 - b = bn.bias - self.bn.running_mean * \ - self.bn.weight / (bn.running_var + bn.eps)**0.5 - w = l.weight * w[None, :] - if l.bias is None: - b = b @self.l.weight.T - else: - b = (l.weight @b[:, None]).view(-1) + self.l.bias - m = paddle.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -def b16(n, activation, resolution=224): - return paddle.nn.Sequential( - Conv2d_BN( - 3, n // 8, 3, 2, 1, resolution=resolution), - activation(), - Conv2d_BN( - n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), - activation(), - Conv2d_BN( - n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), - activation(), - Conv2d_BN( - n // 2, n, 3, 2, 1, resolution=resolution // 8)) - - -class Residual(paddle.nn.Layer): - def __init__(self, m, drop): - super().__init__() - self.m = m - self.drop = drop - - def forward(self, x): - if self.training and self.drop > 0: - return x + self.m(x) * paddle.rand( - x.size(0), 1, 1, - device=x.device).ge_(self.drop).div(1 - self.drop).detach() - else: - return x + self.m(x) - - -class Attention(paddle.nn.Layer): - def __init__(self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - activation=None, - resolution=14): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim**-0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads - self.attn_ratio = attn_ratio - self.h = self.dh + nh_kd * 2 - self.qkv = Linear_BN(dim, self.h, resolution=resolution) - self.proj = paddle.nn.Sequential( - activation(), - Linear_BN( - self.dh, dim, bn_weight_init=0, resolution=resolution)) - points = list(itertools.product(range(resolution), range(resolution))) - N = len(points) - attention_offsets = {} - idxs = [] - for p1 in points: - for p2 in points: - offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) - if offset not in attention_offsets: - 
attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = self.create_parameter( - shape=(num_heads, len(attention_offsets)), - default_initializer=zeros_) - tensor_idxs = paddle.to_tensor(idxs, dtype='int64') - self.register_buffer('attention_bias_idxs', - paddle.reshape(tensor_idxs, [N, N])) - - global FLOPS_COUNTER - #queries * keys - FLOPS_COUNTER += num_heads * (resolution**4) * key_dim - # softmax - FLOPS_COUNTER += num_heads * (resolution**4) - #attention * v - FLOPS_COUNTER += num_heads * self.d * (resolution**4) - - @paddle.no_grad() - def train(self, mode=True): - if mode: - super().train() - else: - super().eval() - if mode and hasattr(self, 'ab'): - del self.ab - else: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - self.ab = attention_biases - #self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,N,C) - self.training = True - B, N, C = x.shape - qkv = self.qkv(x) - qkv = paddle.reshape(qkv, - [B, N, self.num_heads, self.h // self.num_heads]) - q, k, v = paddle.split( - qkv, [self.key_dim, self.key_dim, self.d], axis=3) - q = paddle.transpose(q, perm=[0, 2, 1, 3]) - k = paddle.transpose(k, perm=[0, 2, 1, 3]) - v = paddle.transpose(v, perm=[0, 2, 1, 3]) - k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2]) - - if self.training: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - else: - attention_biases = self.ab - #np_ = paddle.to_tensor(self.attention_biases.numpy()[:, self.attention_bias_idxs.numpy()]) - #print(self.attention_bias_idxs.shape) - #print(attention_biases.shape) - #print(np_.shape) - #print(np_.equal(attention_biases)) - #exit() - - attn = ((q @k_transpose) * self.scale + attention_biases) - attn = F.softmax(attn) - x = paddle.transpose(attn @v, perm=[0, 2, 1, 3]) - x = paddle.reshape(x, [B, N, self.dh]) - x = self.proj(x) - return x - - -class Subsample(paddle.nn.Layer): - def __init__(self, stride, resolution): - super().__init__() - self.stride = stride - self.resolution = resolution - - def forward(self, x): - B, N, C = x.shape - x = paddle.reshape(x, [B, self.resolution, self.resolution, - C])[:, ::self.stride, ::self.stride] - x = paddle.reshape(x, [B, -1, C]) - return x - - -class AttentionSubsample(paddle.nn.Layer): - def __init__(self, - in_dim, - out_dim, - key_dim, - num_heads=8, - attn_ratio=2, - activation=None, - stride=2, - resolution=14, - resolution_=7): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim**-0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads - self.attn_ratio = attn_ratio - self.resolution_ = resolution_ - self.resolution_2 = resolution_**2 - self.training = True - h = self.dh + nh_kd - self.kv = Linear_BN(in_dim, h, resolution=resolution) - - self.q = 
paddle.nn.Sequential( - Subsample(stride, resolution), - Linear_BN( - in_dim, nh_kd, resolution=resolution_)) - self.proj = paddle.nn.Sequential( - activation(), Linear_BN( - self.dh, out_dim, resolution=resolution_)) - - self.stride = stride - self.resolution = resolution - points = list(itertools.product(range(resolution), range(resolution))) - points_ = list( - itertools.product(range(resolution_), range(resolution_))) - - N = len(points) - N_ = len(points_) - attention_offsets = {} - idxs = [] - i = 0 - j = 0 - for p1 in points_: - i += 1 - for p2 in points: - j += 1 - size = 1 - offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), - abs(p1[1] * stride - p2[1] + (size - 1) / 2)) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = self.create_parameter( - shape=(num_heads, len(attention_offsets)), - default_initializer=zeros_) - - tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64') - self.register_buffer('attention_bias_idxs', - paddle.reshape(tensor_idxs_, [N_, N])) - - global FLOPS_COUNTER - #queries * keys - FLOPS_COUNTER += num_heads * \ - (resolution**2) * (resolution_**2) * key_dim - # softmax - FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2) - #attention * v - FLOPS_COUNTER += num_heads * \ - (resolution**2) * (resolution_**2) * self.d - - @paddle.no_grad() - def train(self, mode=True): - if mode: - super().train() - else: - super().eval() - if mode and hasattr(self, 'ab'): - del self.ab - else: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - self.ab = attention_biases - #self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): - self.training = True - B, N, C = x.shape - kv = self.kv(x) - kv = paddle.reshape(kv, [B, N, self.num_heads, -1]) - k, v = paddle.split(kv, [self.key_dim, self.d], axis=3) - k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC - v = paddle.transpose(v, perm=[0, 2, 1, 3]) - q = paddle.reshape( - self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim]) - q = paddle.transpose(q, perm=[0, 2, 1, 3]) - - if self.training: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - else: - attention_biases = self.ab - - attn = (q @paddle.transpose( - k, perm=[0, 1, 3, 2])) * self.scale + attention_biases - attn = F.softmax(attn) - - x = paddle.reshape( - paddle.transpose( - (attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh]) - x = self.proj(x) - return x - - -class LeViT(paddle.nn.Layer): - """ Vision Transformer with support for patch or hybrid CNN input stage - """ - - def __init__(self, - img_size=224, - patch_size=16, - in_chans=3, - class_dim=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], - hybrid_backbone=None, - down_ops=[], - attention_activation=paddle.nn.Hardswish, - mlp_activation=paddle.nn.Hardswish, - 
distillation=True, - drop_path=0): - super().__init__() - global FLOPS_COUNTER - - self.class_dim = class_dim - self.num_features = embed_dim[-1] - self.embed_dim = embed_dim - self.distillation = distillation - - self.patch_embed = hybrid_backbone - - self.blocks = [] - down_ops.append(['']) - resolution = img_size // patch_size - for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( - zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, - down_ops)): - for _ in range(dpth): - self.blocks.append( - Residual( - Attention( - ed, - kd, - nh, - attn_ratio=ar, - activation=attention_activation, - resolution=resolution, ), - drop_path)) - if mr > 0: - h = int(ed * mr) - self.blocks.append( - Residual( - paddle.nn.Sequential( - Linear_BN( - ed, h, resolution=resolution), - mlp_activation(), - Linear_BN( - h, - ed, - bn_weight_init=0, - resolution=resolution), ), - drop_path)) - if do[0] == 'Subsample': - #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - resolution_ = (resolution - 1) // do[5] + 1 - self.blocks.append( - AttentionSubsample( - *embed_dim[i:i + 2], - key_dim=do[1], - num_heads=do[2], - attn_ratio=do[3], - activation=attention_activation, - stride=do[5], - resolution=resolution, - resolution_=resolution_)) - resolution = resolution_ - if do[4] > 0: # mlp_ratio - h = int(embed_dim[i + 1] * do[4]) - self.blocks.append( - Residual( - paddle.nn.Sequential( - Linear_BN( - embed_dim[i + 1], h, resolution=resolution), - mlp_activation(), - Linear_BN( - h, - embed_dim[i + 1], - bn_weight_init=0, - resolution=resolution), ), - drop_path)) - self.blocks = paddle.nn.Sequential(*self.blocks) - - # Classifier head - self.head = BN_Linear( - embed_dim[-1], class_dim) if class_dim > 0 else paddle.nn.Identity() - if distillation: - self.head_dist = BN_Linear( - embed_dim[-1], - class_dim) if class_dim > 0 else paddle.nn.Identity() - - self.FLOPS = FLOPS_COUNTER - FLOPS_COUNTER = 0 - - def no_weight_decay(self): - return {x for x in self.state_dict().keys() if 'attention_biases' in x} - - def forward(self, x): - x = self.patch_embed(x) - x = x.flatten(2) - x = paddle.transpose(x, perm=[0, 2, 1]) - x = self.blocks(x) - x = x.mean(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 - else: - x = self.head(x) - return x - - -def model_factory(C, D, X, N, drop_path, weights, class_dim, distillation, - pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = paddle.nn.Hardswish - model = LeViT( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attention_activation=act, - mlp_activation=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - class_dim=class_dim, - drop_path=drop_path, - distillation=distillation) - # if pretrained: - # checkpoint = torch.hub.load_state_dict_from_url( - # weights, map_location='cpu') - # model.load_state_dict(checkpoint['model']) - if fuse: - utils.replace_batchnorm(model) - - return model - - -if __name__ == '__main__': - ''' - import torch - checkpoint = torch.load('../LeViT/pretrained256.pth') - torch_dict = checkpoint['net'] - paddle_dict = {} - fc_names = ["c.weight", "l.weight", "qkv.weight", "fc1.weight", 
"fc2.weight", "downsample.reduction.weight", "head.weight", "attn.proj.weight"] - rename_dict = {"running_mean": "_mean", "running_var": "_variance"} - range_tuple = (0, 502) - idx = 0 - for key in torch_dict: - idx += 1 - weight = torch_dict[key].cpu().numpy() - flag = [i in key for i in fc_names] - if any(flag): - if "emb" not in key: - print("weight {} need to be trans".format(key)) - weight = weight.transpose() - key = key.replace("running_mean", "_mean") - key = key.replace("running_var", "_variance") - paddle_dict[key]=weight - ''' - import numpy as np - net = globals()['LeViT_256'](fuse=False, - pretrained=False, - distillation=False) - load_layer_state_dict = paddle.load( - "./LeViT_256_official_nodistillation_paddle.pdparams") - #net.set_state_dict(paddle_dict) - net.set_state_dict(load_layer_state_dict) - net.eval() - #paddle.save(net.state_dict(), "./LeViT_256_official_paddle.pdparams") - #model = paddle.jit.to_static(net,input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')]) - #paddle.jit.save(model, "./LeViT_256_official_inference/inference") - #exit() - np.random.seed(123) - img = np.random.rand(1, 3, 224, 224).astype('float32') - img = paddle.to_tensor(img) - outputs = net(img).numpy() - print(outputs[0][:10]) - #print(outputs.shape) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index cd923d78b..c04ff81ad 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -42,6 +42,5 @@ def build_head(config): module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( support_dict)) - print(config) module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py index 055b10973..ed520669e 100644 --- a/ppocr/modeling/heads/rec_aster_head.py +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -43,13 +43,14 @@ class AsterHead(nn.Layer): self.time_step = time_step self.embeder = Embedding(self.time_step, in_channels) self.beam_width = beam_width + self.eos = self.num_classes - 1 def forward(self, x, targets=None, embed=None): return_dict = {} embedding_vectors = self.embeder(x) - rec_targets, rec_lengths = targets if self.training: + rec_targets, rec_lengths, _ = targets rec_pred = self.decoder([x, rec_targets, rec_lengths], embedding_vectors) return_dict['rec_pred'] = rec_pred @@ -104,14 +105,12 @@ class AttentionRecognitionHead(nn.Layer): # Decoder state = self.decoder.get_initial_state(embed) outputs = [] - for i in range(max(lengths)): if i == 0: y_prev = paddle.full( shape=[batch_size], fill_value=self.num_classes) else: y_prev = targets[:, i - 1] - output, state = self.decoder(x, state, y_prev) outputs.append(output) outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1) @@ -142,6 +141,170 @@ class AttentionRecognitionHead(nn.Layer): # return predicted_ids.squeeze(), predicted_scores.squeeze() return predicted_ids, predicted_scores + def beam_search(self, x, beam_width, eos, embed): + def _inflate(tensor, times, dim): + repeat_dims = [1] * tensor.dim() + repeat_dims[dim] = times + output = paddle.tile(tensor, repeat_dims) + return output + + # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py + batch_size, l, d = x.shape + # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC + x = paddle.tile( + paddle.transpose( + x.unsqueeze(1), perm=[1, 0, 2, 
3]), [beam_width, 1, 1, 1])
+        inflated_encoder_feats = paddle.reshape(
+            paddle.transpose(
+                x, perm=[1, 0, 2, 3]), [-1, l, d])
+
+        # Initialize the decoder
+        state = self.decoder.get_initial_state(embed, tile_times=beam_width)
+
+        pos_index = paddle.reshape(
+            paddle.arange(batch_size) * beam_width, shape=[-1, 1])
+
+        # Initialize the scores
+        sequence_scores = paddle.full(
+            shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
+        index = [i * beam_width for i in range(0, batch_size)]
+        sequence_scores[index] = 0.0
+
+        # Initialize the input vector
+        y_prev = paddle.full(
+            shape=[batch_size * beam_width], fill_value=self.num_classes)
+
+        # Store decisions for backtracking
+        stored_scores = list()
+        stored_predecessors = list()
+        stored_emitted_symbols = list()
+
+        for i in range(self.max_len_labels):
+            output, state = self.decoder(inflated_encoder_feats, state, y_prev)
+            state = paddle.unsqueeze(state, axis=0)
+            log_softmax_output = paddle.nn.functional.log_softmax(
+                output, axis=1)
+
+            sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
+            sequence_scores += log_softmax_output
+            scores, candidates = paddle.topk(
+                paddle.reshape(sequence_scores, [batch_size, -1]),
+                beam_width,
+                axis=1)
+
+            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
+            y_prev = paddle.reshape(
+                candidates % self.num_classes, shape=[batch_size * beam_width])
+            sequence_scores = paddle.reshape(
+                scores, shape=[batch_size * beam_width, 1])
+
+            # Update fields for next timestep
+            pos_index = paddle.expand_as(pos_index, candidates)
+            predecessors = paddle.cast(
+                candidates / self.num_classes + pos_index, dtype='int64')
+            predecessors = paddle.reshape(
+                predecessors, shape=[batch_size * beam_width, 1])
+            state = paddle.index_select(
+                state, index=predecessors.squeeze(), axis=1)
+
+            # Update sequence scores and erase scores for the eos symbol so that they aren't expanded
+            stored_scores.append(sequence_scores.clone())
+            y_prev = paddle.reshape(y_prev, shape=[-1, 1])
+            eos_prev = paddle.full_like(y_prev, fill_value=eos)
+            mask = eos_prev == y_prev
+            mask = paddle.nonzero(mask)
+            if mask.dim() > 0:
+                sequence_scores = sequence_scores.numpy()
+                mask = mask.numpy()
+                sequence_scores[mask] = -float('inf')
+                sequence_scores = paddle.to_tensor(sequence_scores)
+
+            # Cache results for backtracking
+            stored_predecessors.append(predecessors)
+            y_prev = paddle.squeeze(y_prev)
+            stored_emitted_symbols.append(y_prev)
+
+        # Do backtracking to return the optimal values
+        #====== backtrack ======#
+        # Initialize return variables given different types
+        p = list()
+        l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
+             ]  # Placeholder for lengths of top-k sequences
+
+        # the last step output of the beams are not sorted
+        # thus they are sorted here
+        sorted_score, sorted_idx = paddle.topk(
+            paddle.reshape(
+                stored_scores[-1], shape=[batch_size, beam_width]),
+            beam_width)
+
+        # initialize the sequence scores with the sorted last step beam scores
+        s = sorted_score.clone()
+
+        batch_eos_found = [0] * batch_size  # the number of EOS found
+        # in the backward loop below for each batch
+        t = self.max_len_labels - 1
+        # initialize the back pointer with the sorted order of the last step beams.
+        # add pos_index for indexing variable with b*k as the first dimension. 
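
# --- Illustrative aside (not part of the diff): the index bookkeeping used in
# this backtracking pass. topk runs over a flattened
# [beam_width * num_classes] score table per batch element, so the emitted
# symbol and the predecessor beam are recovered by mod/div, and pos_index
# (batch_idx * beam_width) shifts per-batch beam indices into the flat
# batch*beam dimension. Numbers below are made up.
import paddle

num_classes, beam_width = 5, 2
candidates = paddle.to_tensor([[7, 3]])  # top-2 flat indices for one sample
symbols = candidates % num_classes       # [[2, 3]] -> emitted symbols
beams = candidates // num_classes        # [[1, 0]] -> which beam they extend
pos_index = paddle.to_tensor([[0]])      # batch_idx * beam_width
print(beams + pos_index)                 # [[1, 0]] flat predecessor indices
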
+ t_predecessors = paddle.reshape( + sorted_idx + pos_index.expand_as(sorted_idx), + shape=[batch_size * beam_width]) + while t >= 0: + # Re-order the variables with the back pointer + current_symbol = paddle.index_select( + stored_emitted_symbols[t], index=t_predecessors, axis=0) + t_predecessors = paddle.index_select( + stored_predecessors[t].squeeze(), index=t_predecessors, axis=0) + eos_indices = stored_emitted_symbols[t] == eos + eos_indices = paddle.nonzero(eos_indices) + + if eos_indices.dim() > 0: + for i in range(eos_indices.shape[0] - 1, -1, -1): + # Indices of the EOS symbol for both variables + # with b*k as the first dimension, and b, k for + # the first two dimensions + idx = eos_indices[i] + b_idx = int(idx[0] / beam_width) + # The indices of the replacing position + # according to the replacement strategy noted above + res_k_idx = beam_width - (batch_eos_found[b_idx] % + beam_width) - 1 + batch_eos_found[b_idx] += 1 + res_idx = b_idx * beam_width + res_k_idx + + # Replace the old information in return variables + # with the new ended sequence information + t_predecessors[res_idx] = stored_predecessors[t][idx[0]] + current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]] + s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0] + l[b_idx][res_k_idx] = t + 1 + + # record the back tracked results + p.append(current_symbol) + t -= 1 + + # Sort and re-order again as the added ended sequences may change + # the order (very unlikely) + s, re_sorted_idx = s.topk(beam_width) + for b_idx in range(batch_size): + l[b_idx] = [ + l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :] + ] + + re_sorted_idx = paddle.reshape( + re_sorted_idx + pos_index.expand_as(re_sorted_idx), + [batch_size * beam_width]) + + # Reverse the sequences and re-order at the same time + # It is reversed because the backtracking happens in reverse time order + p = [ + paddle.reshape( + paddle.index_select(step, re_sorted_idx, 0), + shape=[batch_size, beam_width, -1]) for step in reversed(p) + ] + p = paddle.concat(p, -1)[:, 0, :] + return p, paddle.ones_like(p) + class AttentionUnit(nn.Layer): def __init__(self, sDim, xDim, attDim): @@ -151,21 +314,9 @@ class AttentionUnit(nn.Layer): self.xDim = xDim self.attDim = attDim - self.sEmbed = nn.Linear( - sDim, - attDim, - weight_attr=paddle.nn.initializer.Normal(std=0.01), - bias_attr=paddle.nn.initializer.Constant(0.0)) - self.xEmbed = nn.Linear( - xDim, - attDim, - weight_attr=paddle.nn.initializer.Normal(std=0.01), - bias_attr=paddle.nn.initializer.Constant(0.0)) - self.wEmbed = nn.Linear( - attDim, - 1, - weight_attr=paddle.nn.initializer.Normal(std=0.01), - bias_attr=paddle.nn.initializer.Constant(0.0)) + self.sEmbed = nn.Linear(sDim, attDim) + self.xEmbed = nn.Linear(xDim, attDim) + self.wEmbed = nn.Linear(attDim, 1) def forward(self, x, sPrev): batch_size, T, _ = x.shape # [b x T x xDim] @@ -184,10 +335,8 @@ class AttentionUnit(nn.Layer): vProj = self.wEmbed(sumTanh) # [(b x T) x 1] vProj = paddle.reshape(vProj, [batch_size, T]) - alpha = F.softmax( vProj, axis=1) # attention weights for each sample in the minibatch - return alpha @@ -238,21 +387,4 @@ class DecoderUnit(nn.Layer): output, state = self.gru(concat_context, sPrev) output = paddle.squeeze(output, axis=1) output = self.fc(output) - return output, state - - -if __name__ == "__main__": - model = AttentionRecognitionHead( - num_classes=20, - in_channels=30, - sDim=512, - attDim=512, - max_len_labels=25, - out_channels=38) - - data = paddle.ones([16, 64, 3]) - targets = paddle.ones([16, 25]) - length 
= paddle.to_tensor(20) - x = [data, targets, length] - output = model(x) - print(output.shape) + return output, state \ No newline at end of file diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index 79f112f72..4286d7691 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -44,13 +44,10 @@ class AttentionHead(nn.Layer): hidden = paddle.zeros((batch_size, self.hidden_size)) output_hiddens = [] - targets = targets[0] - print(targets) if targets is not None: for i in range(num_steps): char_onehots = self._char_to_onehot( targets[:, i], onehot_dim=self.num_classes) - # print("char_onehots:", char_onehots) (outputs, hidden), alpha = self.attention_cell(hidden, inputs, char_onehots) output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) @@ -107,8 +104,6 @@ class AttentionGRUCell(nn.Layer): alpha = paddle.transpose(alpha, [0, 2, 1]) context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) concat_context = paddle.concat([context, char_onehots], 1) - # print("concat_context:", concat_context.shape) - # print("prev_hidden:", prev_hidden.shape) cur_hidden = self.rnn(concat_context, prev_hidden) diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py index 0b26e27ae..23bd21891 100644 --- a/ppocr/modeling/transforms/stn.py +++ b/ppocr/modeling/transforms/stn.py @@ -106,16 +106,3 @@ class STN(nn.Layer): x = F.sigmoid(x) x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2]) return img_feat, x - - -if __name__ == "__main__": - in_planes = 3 - num_ctrlpoints = 20 - np.random.seed(100) - activation = 'none' # 'sigmoid' - stn_head = STN(in_planes, num_ctrlpoints, activation) - data = np.random.randn(10, 3, 32, 64).astype("float32") - print("data:", np.sum(data)) - input = paddle.to_tensor(data) - #input = paddle.randn([10, 3, 32, 64]) - control_points = stn_head(input) diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index fc4621007..de4bb7a68 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -326,5 +326,6 @@ class STN_ON(nn.Layer): image, self.tps_inputsize, mode="bilinear", align_corners=True) stn_img_feat, ctrl_points = self.stn_head(stn_input) x, _ = self.tps(image, ctrl_points) + #print("x:", np.sum(x.numpy())) # print(x.shape) return x diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py index da54ffb78..731e3ee9f 100644 --- a/ppocr/modeling/transforms/tps_spatial_transformer.py +++ b/ppocr/modeling/transforms/tps_spatial_transformer.py @@ -136,7 +136,8 @@ class TPSSpatialTransformer(nn.Layer): assert source_control_points.ndimension() == 3 assert source_control_points.shape[1] == self.num_control_points assert source_control_points.shape[2] == 2 - batch_size = source_control_points.shape[0] + #batch_size = source_control_points.shape[0] + batch_size = paddle.shape(source_control_points)[0] self.padding_matrix = paddle.expand( self.padding_matrix, shape=[batch_size, 3, 2]) @@ -151,28 +152,6 @@ class TPSSpatialTransformer(nn.Layer): grid = paddle.clip(grid, 0, 1) # the source_control_points may be out of [0, 1]. 
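The hunk continuing below re-enables the rescaling of the TPS sampling grid from the [0, 1] range the spline produces to the [-1, 1] range that grid_sample expects. A minimal sketch of that convention, assuming Paddle's built-in F.grid_sample with align_corners=True (the function name here is illustrative, not part of the patch):

    import paddle
    import paddle.nn.functional as F

    def sample_with_tps_grid(image, grid01):
        # image: [N, C, H, W]; grid01: [N, H_out, W_out, 2] from TPS, in [0, 1]
        grid01 = paddle.clip(grid01, 0, 1)  # control points may fall outside [0, 1]
        grid = 2.0 * grid01 - 1.0           # map [0, 1] -> [-1, 1]
        return F.grid_sample(image, grid, align_corners=True)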
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] - # grid = 2.0 * grid - 1.0 + grid = 2.0 * grid - 1.0 output_maps = grid_sample(input, grid, canvas=None) return output_maps, source_coordinate - - -if __name__ == "__main__": - from stn import STN - in_planes = 3 - num_ctrlpoints = 20 - np.random.seed(100) - activation = 'none' # 'sigmoid' - stn_head = STN(in_planes, num_ctrlpoints, activation) - data = np.random.randn(10, 3, 32, 64).astype("float32") - input = paddle.to_tensor(data) - #input = paddle.randn([10, 3, 32, 64]) - control_points = stn_head(input) - #print("control points:", control_points) - #input = paddle.randn(shape=[10,3,32,100]) - tps = TPSSpatialTransformer( - output_image_size=[32, 320], - num_control_points=20, - margins=[0.05, 0.05]) - out = tps(input, control_points[1]) - print("out 0 :", out[0].shape) - print("out 1:", out[1].shape) diff --git a/ppocr/modeling/transforms/tps_torch.py b/ppocr/modeling/transforms/tps_torch.py deleted file mode 100644 index 7aee133ae..000000000 --- a/ppocr/modeling/transforms/tps_torch.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import absolute_import - -import numpy as np -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def grid_sample(input, grid, canvas=None): - output = F.grid_sample(input, grid) - if canvas is None: - return output - else: - input_mask = input.data.new(input.size()).fill_(1) - output_mask = F.grid_sample(input_mask, grid) - padded_output = output * output_mask + canvas * (1 - output_mask) - return padded_output - - -# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 -def compute_partial_repr(input_points, control_points): - N = input_points.size(0) - M = control_points.size(0) - pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) - # original implementation, very slow - # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance - pairwise_diff_square = pairwise_diff * pairwise_diff - pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, - 1] - repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) - # fix numerical error for 0 * log(0), substitute all nan with 0 - mask = repr_matrix != repr_matrix - repr_matrix.masked_fill_(mask, 0) - return repr_matrix - - -# output_ctrl_pts are specified, according to our task. 
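The helper deleted just below lays the target control points out in two horizontal rows inside the margins; the Paddle counterpart that remains in ppocr/modeling/transforms/tps_spatial_transformer.py presumably keeps the same layout. A standalone NumPy sketch re-derived from the deleted code (illustrative, not new behavior):

    import numpy as np

    def build_output_control_points(num_control_points, margins):
        # Half the points run along y = margin_y (top edge), the other half
        # along y = 1 - margin_y (bottom edge), x evenly spaced between margins.
        margin_x, margin_y = margins
        per_side = num_control_points // 2
        xs = np.linspace(margin_x, 1.0 - margin_x, per_side)
        top = np.stack([xs, np.full(per_side, margin_y)], axis=1)
        bottom = np.stack([xs, np.full(per_side, 1.0 - margin_y)], axis=1)
        return np.concatenate([top, bottom], axis=0)  # [2 * per_side, 2], (x, y)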
-def build_output_control_points(num_control_points, margins): - margin_x, margin_y = margins - num_ctrl_pts_per_side = num_control_points // 2 - ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) - ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y - ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) - ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) - ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) - # ctrl_pts_top = ctrl_pts_top[1:-1,:] - # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] - output_ctrl_pts_arr = np.concatenate( - [ctrl_pts_top, ctrl_pts_bottom], axis=0) - output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) - return output_ctrl_pts - - -# demo: ~/test/models/test_tps_transformation.py -class TPSSpatialTransformer(nn.Module): - def __init__(self, - output_image_size=None, - num_control_points=None, - margins=None): - super(TPSSpatialTransformer, self).__init__() - self.output_image_size = output_image_size - self.num_control_points = num_control_points - self.margins = margins - - self.target_height, self.target_width = output_image_size - target_control_points = build_output_control_points(num_control_points, - margins) - N = num_control_points - # N = N - 4 - - # create padded kernel matrix - forward_kernel = torch.zeros(N + 3, N + 3) - target_control_partial_repr = compute_partial_repr( - target_control_points, target_control_points) - forward_kernel[:N, :N].copy_(target_control_partial_repr) - forward_kernel[:N, -3].fill_(1) - forward_kernel[-3, :N].fill_(1) - forward_kernel[:N, -2:].copy_(target_control_points) - forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) - # compute inverse matrix - inverse_kernel = torch.inverse(forward_kernel) - - # create target cordinate matrix - HW = self.target_height * self.target_width - target_coordinate = list( - itertools.product( - range(self.target_height), range(self.target_width))) - target_coordinate = torch.Tensor(target_coordinate) # HW x 2 - Y, X = target_coordinate.split(1, dim=1) - Y = Y / (self.target_height - 1) - X = X / (self.target_width - 1) - target_coordinate = torch.cat([X, Y], - dim=1) # convert from (y, x) to (x, y) - target_coordinate_partial_repr = compute_partial_repr( - target_coordinate, target_control_points) - target_coordinate_repr = torch.cat([ - target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate - ], - dim=1) - - # register precomputed matrices - self.register_buffer('inverse_kernel', inverse_kernel) - self.register_buffer('padding_matrix', torch.zeros(3, 2)) - self.register_buffer('target_coordinate_repr', target_coordinate_repr) - self.register_buffer('target_control_points', target_control_points) - - def forward(self, input, source_control_points): - assert source_control_points.ndimension() == 3 - assert source_control_points.size(1) == self.num_control_points - assert source_control_points.size(2) == 2 - batch_size = source_control_points.size(0) - - Y = torch.cat([ - source_control_points, self.padding_matrix.expand(batch_size, 3, 2) - ], 1) - mapping_matrix = torch.matmul(self.inverse_kernel, Y) - source_coordinate = torch.matmul(self.target_coordinate_repr, - mapping_matrix) - - grid = source_coordinate.view(-1, self.target_height, self.target_width, - 2) - grid = torch.clamp(grid, 0, - 1) # the source_control_points may be out of [0, 1]. 
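Both this deleted torch reference and the Paddle implementation fill their kernel matrices with the thin-plate spline radial basis phi(r) = r^2 * log(r), written as 0.5 * r^2 * log(r^2) in compute_partial_repr above, with the 0 * log(0) singularity patched to zero. A self-contained NumPy sketch of the same computation (illustrative only):

    import numpy as np

    def tps_kernel(input_points, control_points):
        # phi(r) = r^2 * log(r) = 0.5 * r^2 * log(r^2), r = ||x1 - x2||_2
        diff = input_points[:, None, :] - control_points[None, :, :]  # [N, M, 2]
        dist_sq = np.sum(diff * diff, axis=2)
        with np.errstate(divide='ignore', invalid='ignore'):
            repr_matrix = 0.5 * dist_sq * np.log(dist_sq)
        return np.nan_to_num(repr_matrix)  # 0 * log(0) -> NaN -> 0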
-        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
-        grid = 2.0 * grid - 1.0
-        output_maps = grid_sample(input, grid, canvas=None)
-        return output_maps, source_coordinate
-
-
-if __name__ == "__main__":
-    from stn_torch import STNHead
-    in_planes = 3
-    num_ctrlpoints = 20
-    torch.manual_seed(10)
-    activation = 'none'  # 'sigmoid'
-    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
-    np.random.seed(100)
-    data = np.random.randn(10, 3, 32, 64).astype("float32")
-    input = torch.tensor(data)
-    control_points = stn_head(input)
-    tps = TPSSpatialTransformer(
-        output_image_size=[32, 320],
-        num_control_points=20,
-        margins=[0.05, 0.05])
-    out = tps(input, control_points[1])
-    print("out 0 :", out[0].shape)
-    print("out 1:", out[1].shape)
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index 8215b92d8..34098c0fa 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -127,3 +127,33 @@ class RMSProp(object):
             grad_clip=self.grad_clip,
             parameters=parameters)
         return opt
+
+
+class Adadelta(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 epsilon=1e-08,
+                 rho=0.95,
+                 parameter_list=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None,
+                 **kwargs):
+        self.learning_rate = learning_rate
+        self.epsilon = epsilon
+        self.rho = rho
+        self.parameter_list = parameter_list
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+
+    def __call__(self, parameters):
+        opt = optim.Adadelta(
+            learning_rate=self.learning_rate,
+            epsilon=self.epsilon,
+            rho=self.rho,
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            name=self.name,
+            parameters=parameters)
+        return opt
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 2f5bdc3b1..ba7e06db2 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess
 from .east_postprocess import EASTPostProcess
 from .sast_postprocess import SASTPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
-    TableLabelDecode
+    TableLabelDecode, SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 
@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
-        'DistillationCTCLabelDecode', 'TableLabelDecode'
+        'DistillationCTCLabelDecode', 'TableLabelDecode', 'SEEDLabelDecode'
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 17fc7e461..921d619a3 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unkonwn = "UNKNOWN"
         dict_character = dict_character
-        dict_character = [self.beg_str] + dict_character + [self.end_str
-                                                            ] + [self.unkonwn]
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
             label = self.decode(label, is_remove_duplicate=False)
             return text, label
         """
-
         preds = preds["rec_pred"]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
         return idx
 
 
+class SEEDLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelDecode, self).__init__(character_dict_path,
+                                              character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def get_ignored_tokens(self):
+        end_idx = self.get_beg_end_flag_idx("eos")
+        return [end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "sos":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "eos":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
+        return idx
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        [end_idx] = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list)))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        """
+        text = self.decode(text)
+        if label is None:
+            return text
+        else:
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+        """
+        if "rec_pred_scores" in preds:
+            preds_idx = preds["rec_pred"]
+            preds_prob = preds["rec_pred_scores"]
+        else:
+            preds_idx = preds["rec_pred"].argmax(axis=2)
+            preds_prob = preds["rec_pred"].max(axis=2)
+        if isinstance(preds_idx, paddle.Tensor):
+            preds_idx = preds_idx.numpy()
+        if isinstance(preds_prob, paddle.Tensor):
+            preds_prob = preds_prob.numpy()
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+        if label is None:
+            return text
+        label = self.decode(label, is_remove_duplicate=False)
+        return text, label
+
+
 class SRNLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
 
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 0453509c7..1d760e983 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -105,16 +105,13 @@ def load_dygraph_params(config, model, logger, optimizer):
             params = paddle.load(pm)
             state_dict = model.state_dict()
             new_state_dict = {}
-            # for k1, k2 in zip(state_dict.keys(), params.keys()):
-            for k1 in state_dict.keys():
-                if k1 not in params:
-                    continue
-                if list(state_dict[k1].shape) == list(params[k1].shape):
-                    new_state_dict[k1] = params[k1]
-                else:
-                    logger.info(
-                        f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
- ) + for k1, k2 in zip(state_dict.keys(), params.keys()): + if list(state_dict[k1].shape) == list(params[k2].shape): + new_state_dict[k1] = params[k2] + else: + logger.info( + f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" + ) model.set_state_dict(new_state_dict) logger.info(f"loaded pretrained_model successful from {pm}") return {} diff --git a/tools/program.py b/tools/program.py index 920cf417c..3479ff26f 100755 --- a/tools/program.py +++ b/tools/program.py @@ -211,11 +211,10 @@ def train(config, images = batch[0] if use_srn: model_average = True - # if use_srn or model_type == 'table' or algorithm == "ASTER": - # preds = model(images, data=batch[1:]) - # else: - # preds = model(images) - preds = model(images, data=batch[1:]) + if use_srn or model_type == 'table' or model_type == "seed": + preds = model(images, data=batch[1:]) + else: + preds = model(images) state_dict = model.state_dict() # for key in state_dict: # print(key) @@ -415,6 +414,7 @@ def preprocess(is_train=False): yaml.dump( dict(config), f, default_flow_style=False, sort_keys=False) log_file = '{}/train.log'.format(save_model_dir) + print("log has save in {}/train.log".format(save_model_dir)) else: log_file = None logger = get_logger(name='root', log_file=log_file) From 1b2ca6e641fb8c5e2ac5171f38fb5f37cbdee785 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Mon, 30 Aug 2021 06:37:26 +0000 Subject: [PATCH 3/3] polish code --- ppocr/modeling/transforms/tps.py | 2 -- tools/program.py | 4 ---- tools/train.py | 2 -- 3 files changed, 8 deletions(-) diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index de4bb7a68..81221b035 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -326,6 +326,4 @@ class STN_ON(nn.Layer): image, self.tps_inputsize, mode="bilinear", align_corners=True) stn_img_feat, ctrl_points = self.stn_head(stn_input) x, _ = self.tps(image, ctrl_points) - #print("x:", np.sum(x.numpy())) - # print(x.shape) return x diff --git a/tools/program.py b/tools/program.py index 3479ff26f..f77c69f88 100755 --- a/tools/program.py +++ b/tools/program.py @@ -215,9 +215,6 @@ def train(config, preds = model(images, data=batch[1:]) else: preds = model(images) - state_dict = model.state_dict() - # for key in state_dict: - # print(key) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -414,7 +411,6 @@ def preprocess(is_train=False): yaml.dump( dict(config), f, default_flow_style=False, sort_keys=False) log_file = '{}/train.log'.format(save_model_dir) - print("log has save in {}/train.log".format(save_model_dir)) else: log_file = None logger = get_logger(name='root', log_file=log_file) diff --git a/tools/train.py b/tools/train.py index e1515f57c..20f5a670d 100755 --- a/tools/train.py +++ b/tools/train.py @@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - character = getattr(post_process_class, 'character') - print("getattr character:", character) if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]:
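Taken together, the pieces added in this series are exercised like any other recognition post-process. A minimal usage sketch of the newly registered SEEDLabelDecode, with a dummy prediction dict standing in for real AsterHead output (the config keys mirror the yml added in this patch; the random scores are illustrative only):

    import numpy as np
    from ppocr.postprocess import build_post_process

    config = {'name': 'SEEDLabelDecode', 'character_type': 'EN_symbol'}
    post_process = build_post_process(config)

    # The ASTER/SEED head returns a dict; "rec_pred" holds per-step class scores.
    num_classes = len(post_process.character)
    preds = {'rec_pred': np.random.rand(1, 25, num_classes).astype('float32')}
    print(post_process(preds))  # -> [(decoded_text, mean_confidence)]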