Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph
commit 70793cb847
@@ -27,12 +27,12 @@ class CosineEmbeddingLoss(nn.Layer):
         self.epsilon = 1e-12

     def forward(self, x1, x2, target):
-        similarity = paddle.fluid.layers.reduce_sum(
+        similarity = paddle.sum(
             x1 * x2, dim=-1) / (paddle.norm(
                 x1, axis=-1) * paddle.norm(
                     x2, axis=-1) + self.epsilon)
         one_list = paddle.full_like(target, fill_value=1)
-        out = paddle.fluid.layers.reduce_mean(
+        out = paddle.mean(
             paddle.where(
                 paddle.equal(target, one_list), 1. - similarity,
                 paddle.maximum(
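Note: this hunk renames `reduce_sum`/`reduce_mean`, but the continuation line `x1 * x2, dim=-1` appears untouched, and `paddle.sum` accepts `axis`, not `dim`, so that argument needs the same rename to run under Paddle 2.x. Below is a minimal runnable sketch of the migrated loss under that assumption; the `margin` default and the `paddle.zeros_like` hinge term that completes the truncated `paddle.maximum(` call are assumptions, not part of this hunk.

import paddle
from paddle import nn

class CosineEmbeddingLoss(nn.Layer):
    def __init__(self, margin=0.0):  # margin default is an assumption
        super(CosineEmbeddingLoss, self).__init__()
        self.margin = margin
        self.epsilon = 1e-12

    def forward(self, x1, x2, target):
        # Cosine similarity along the last axis; epsilon guards against
        # division by zero. paddle.sum takes `axis`, not the old `dim`.
        similarity = paddle.sum(
            x1 * x2, axis=-1) / (paddle.norm(
                x1, axis=-1) * paddle.norm(
                    x2, axis=-1) + self.epsilon)
        one_list = paddle.full_like(target, fill_value=1)
        # Positive pairs (target == 1) are pulled together via 1 - sim;
        # the hinge on negative pairs is an assumed completion of the hunk.
        return paddle.mean(
            paddle.where(
                paddle.equal(target, one_list), 1. - similarity,
                paddle.maximum(
                    paddle.zeros_like(similarity), similarity - self.margin)))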
@@ -19,7 +19,6 @@ from __future__ import print_function
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
-from paddle import fluid

 class TableAttentionLoss(nn.Layer):
     def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
@@ -36,13 +35,13 @@ class TableAttentionLoss(nn.Layer):
         :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
         :return: loss
         '''
-        ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
-        iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
-        ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
-        iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
+        ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
+        iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
+        ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
+        iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])

-        iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
-        ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
+        iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
+        ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)

         # overlap
         inters = iw * ih
@@ -55,12 +54,12 @@ class TableAttentionLoss(nn.Layer):
         # ious
         ious = inters / uni

-        ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
-        ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
-        ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
-        ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
-        ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
-        eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
+        ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
+        ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
+        ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
+        ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
+        ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
+        eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)

         # enclose area
         enclose = ew * eh + eps
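Together with the previous hunk, this migrates the GIoU computation of `TableAttentionLoss` off `fluid`. A self-contained sketch of the migrated math, assuming `preds` and `bbox` are `[N, 4]` tensors in `(x1, y1, x2, y2)` order; the union term `uni`, the `eps` value, and the final loss form sit outside the hunks shown, so their exact definitions here are assumptions:

import paddle

def giou_loss(preds, bbox, eps=1e-10):
    # Intersection corners via elementwise max/min
    # (fluid.layers.elementwise_max/min -> paddle.maximum/minimum).
    ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
    iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
    ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
    iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
    # Clamp widths/heights to non-negative (fluid.layers.clip -> paddle.clip).
    iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
    ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
    inters = iw * ih
    # Union of the two boxes; these lines lie outside the hunks shown.
    area_p = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3)
    area_g = (bbox[:, 2] - bbox[:, 0] + 1e-3) * (bbox[:, 3] - bbox[:, 1] + 1e-3)
    uni = area_p + area_g - inters + eps
    ious = inters / uni
    # Smallest enclosing box for the GIoU penalty term.
    ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
    ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
    ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
    ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
    ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
    eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
    enclose = ew * eh + eps
    giou = ious - (enclose - uni) / enclose
    return 1.0 - paddle.mean(giou)  # loss form is an assumption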
@@ -175,7 +175,7 @@ class Kie_backbone(nn.Layer):
             img, relations, texts, gt_bboxes, tag, img_size)
         x = self.img_feat(img)
         boxes, rois_num = self.bbox2roi(gt_bboxes)
-        feats = paddle.fluid.layers.roi_align(
+        feats = paddle.vision.ops.roi_align(
             x,
             boxes,
             spatial_scale=1.0,
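The call above is truncated after `spatial_scale=1.0`. A hedged usage sketch of the replacement operator; the `output_size=7` value and the `boxes_num=rois_num` keyword are assumptions inferred from the surrounding code, not visible in this hunk:

import paddle

# x: [N, C, H, W] feature map; boxes: [num_rois, 4] in (x1, y1, x2, y2);
# rois_num: [N] int32 tensor giving the number of RoIs per image.
x = paddle.rand([1, 256, 64, 64])
boxes = paddle.to_tensor([[4., 4., 32., 32.], [10., 10., 50., 40.]])
rois_num = paddle.to_tensor([2], dtype='int32')

# paddle.vision.ops.roi_align replaces paddle.fluid.layers.roi_align.
feats = paddle.vision.ops.roi_align(
    x, boxes, boxes_num=rois_num, output_size=7, spatial_scale=1.0)
print(feats.shape)  # [2, 256, 7, 7]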
@@ -18,7 +18,6 @@ from __future__ import print_function

 from paddle import nn, ParamAttr
 from paddle.nn import functional as F
-import paddle.fluid as fluid
 import paddle
 import numpy as np

@@ -20,13 +20,11 @@ import math
 import paddle
 from paddle import nn, ParamAttr
 from paddle.nn import functional as F
-import paddle.fluid as fluid
 import numpy as np
 from .self_attention import WrapEncoderForFeature
 from .self_attention import WrapEncoder
 from paddle.static import Program
 from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
-import paddle.fluid.framework as framework

 from collections import OrderedDict
 gradient_clip = 10
@@ -22,7 +22,6 @@ import paddle
-from paddle import ParamAttr, nn
+from paddle import nn, ParamAttr
 from paddle.nn import functional as F
-import paddle.fluid as fluid
 import numpy as np
 gradient_clip = 10

@@ -288,10 +287,10 @@ class PrePostProcessLayer(nn.Layer):
                     "layer_norm_%d" % len(self.sublayers()),
                     paddle.nn.LayerNorm(
                         normalized_shape=d_model,
-                        weight_attr=fluid.ParamAttr(
-                            initializer=fluid.initializer.Constant(1.)),
-                        bias_attr=fluid.ParamAttr(
-                            initializer=fluid.initializer.Constant(0.)))))
+                        weight_attr=paddle.ParamAttr(
+                            initializer=paddle.nn.initializer.Constant(1.)),
+                        bias_attr=paddle.ParamAttr(
+                            initializer=paddle.nn.initializer.Constant(0.)))))
         elif cmd == "d":  # add dropout
             self.functors.append(lambda x: F.dropout(
                 x, p=dropout_rate, mode="downscale_in_infer")
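A standalone sketch of the migrated normalization setup, assuming `d_model` is the feature width (the surrounding `functors`/`cmd` dispatch machinery is omitted):

import paddle

d_model = 512  # assumed feature width

# paddle.ParamAttr and paddle.nn.initializer.Constant replace the old
# fluid.ParamAttr / fluid.initializer.Constant pair.
layer_norm = paddle.nn.LayerNorm(
    normalized_shape=d_model,
    weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.Constant(1.)),
    bias_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.Constant(0.)))

x = paddle.rand([2, 10, d_model])
print(layer_norm(x).shape)  # [2, 10, 512]

In the same hunk, `F.dropout(..., mode="downscale_in_infer")` keeps the old fluid dropout semantics (activations scaled at inference) rather than the Paddle 2.x default `upscale_in_train`.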
@@ -324,7 +323,7 @@ class PrepareEncoder(nn.Layer):

     def forward(self, src_word, src_pos):
         src_word_emb = src_word
-        src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
+        src_word_emb = paddle.cast(src_word_emb, 'float32')
         src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
         src_pos = paddle.squeeze(src_pos, axis=-1)
         src_pos_enc = self.emb(src_pos)
@@ -367,7 +366,7 @@ class PrepareDecoder(nn.Layer):
         self.dropout_rate = dropout_rate

     def forward(self, src_word, src_pos):
-        src_word = fluid.layers.cast(src_word, 'int64')
+        src_word = paddle.cast(src_word, 'int64')
         src_word = paddle.squeeze(src_word, axis=-1)
         src_word_emb = self.emb0(src_word)
         src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
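The last two hunks make the same change in `PrepareEncoder.forward` and `PrepareDecoder.forward`: `fluid.layers.cast` becomes `paddle.cast`. A minimal sketch of the surrounding embedding arithmetic, with the vocabulary size and `src_emb_dim` assumed for illustration:

import paddle

src_emb_dim = 512  # assumed embedding width
emb = paddle.nn.Embedding(100, src_emb_dim)  # vocab size 100 is assumed

src_word = paddle.randint(0, 100, [2, 8, 1])
# fluid.layers.cast -> paddle.cast; the dtype is passed as a string.
src_word = paddle.cast(src_word, 'int64')
# Drop the trailing singleton axis before the embedding lookup.
src_word = paddle.squeeze(src_word, axis=-1)
src_word_emb = emb(src_word)
# Scale by sqrt(d_model), as in the Transformer input embedding.
src_word_emb = paddle.scale(x=src_word_emb, scale=src_emb_dim**0.5)
print(src_word_emb.shape)  # [2, 8, 512]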