mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
refine
This commit is contained in:
parent
de6bc4458b
commit
c0492e02c7
@ -69,7 +69,7 @@ class BaseModel(nn.Layer):
|
|||||||
|
|
||||||
self.return_all_feats = config.get("return_all_feats", False)
|
self.return_all_feats = config.get("return_all_feats", False)
|
||||||
|
|
||||||
def forward(self, x, data=None, mode='Train'):
|
def forward(self, x, data=None):
|
||||||
y = dict()
|
y = dict()
|
||||||
if self.use_transform:
|
if self.use_transform:
|
||||||
x = self.transform(x)
|
x = self.transform(x)
|
||||||
@ -78,13 +78,7 @@ class BaseModel(nn.Layer):
|
|||||||
if self.use_neck:
|
if self.use_neck:
|
||||||
x = self.neck(x)
|
x = self.neck(x)
|
||||||
y["neck_out"] = x
|
y["neck_out"] = x
|
||||||
if data is None:
|
x = self.head(x, targets=data)
|
||||||
x = self.head(x)
|
|
||||||
else:
|
|
||||||
if mode == 'Eval' or mode == 'Test':
|
|
||||||
x = self.head(x, targets=data, mode=mode)
|
|
||||||
else:
|
|
||||||
x = self.head(x, targets=data)
|
|
||||||
y["head_out"] = x
|
y["head_out"] = x
|
||||||
if self.return_all_feats:
|
if self.return_all_feats:
|
||||||
return y
|
return y
|
||||||
|
@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
|
|||||||
initializer=nn.initializer.Uniform(-stdv, stdv)),
|
initializer=nn.initializer.Uniform(-stdv, stdv)),
|
||||||
bias_attr=ParamAttr(name="fc_0.b_0"), )
|
bias_attr=ParamAttr(name="fc_0.b_0"), )
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, targets=None):
|
||||||
x = self.pool(x)
|
x = self.pool(x)
|
||||||
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
|
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
|
@ -106,7 +106,7 @@ class DBHead(nn.Layer):
|
|||||||
def step_function(self, x, y):
|
def step_function(self, x, y):
|
||||||
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, targets=None):
|
||||||
shrink_maps = self.binarize(x)
|
shrink_maps = self.binarize(x)
|
||||||
if not self.training:
|
if not self.training:
|
||||||
return {'maps': shrink_maps}
|
return {'maps': shrink_maps}
|
||||||
|
@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
|
|||||||
act=None,
|
act=None,
|
||||||
name="f_geo")
|
name="f_geo")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, targets=None):
|
||||||
f_det = self.det_conv1(x)
|
f_det = self.det_conv1(x)
|
||||||
f_det = self.det_conv2(f_det)
|
f_det = self.det_conv2(f_det)
|
||||||
f_score = self.score_conv(f_det)
|
f_score = self.score_conv(f_det)
|
||||||
|
@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
|
|||||||
self.head1 = SAST_Header1(in_channels)
|
self.head1 = SAST_Header1(in_channels)
|
||||||
self.head2 = SAST_Header2(in_channels)
|
self.head2 = SAST_Header2(in_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, targets=None):
|
||||||
f_score, f_border = self.head1(x)
|
f_score, f_border = self.head1(x)
|
||||||
f_tvo, f_tco = self.head2(x)
|
f_tvo, f_tco = self.head2(x)
|
||||||
|
|
||||||
|
@ -220,7 +220,7 @@ class PGHead(nn.Layer):
|
|||||||
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
||||||
bias_attr=False)
|
bias_attr=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, targets=None):
|
||||||
f_score = self.conv_f_score1(x)
|
f_score = self.conv_f_score1(x)
|
||||||
f_score = self.conv_f_score2(f_score)
|
f_score = self.conv_f_score2(f_score)
|
||||||
f_score = self.conv_f_score3(f_score)
|
f_score = self.conv_f_score3(f_score)
|
||||||
|
@ -44,7 +44,7 @@ class CTCHead(nn.Layer):
|
|||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
def forward(self, x, labels=None):
|
def forward(self, x, targets=None):
|
||||||
predicts = self.fc(x)
|
predicts = self.fc(x)
|
||||||
if not self.training:
|
if not self.training:
|
||||||
predicts = F.softmax(predicts, axis=2)
|
predicts = F.softmax(predicts, axis=2)
|
||||||
|
@ -53,7 +53,7 @@ class TableAttentionHead(nn.Layer):
|
|||||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
return input_ont_hot
|
return input_ont_hot
|
||||||
|
|
||||||
def forward(self, inputs, targets=None, mode='Train'):
|
def forward(self, inputs, targets=None):
|
||||||
# if and else branch are both needed when you want to assign a variable
|
# if and else branch are both needed when you want to assign a variable
|
||||||
# if you modify the var in just one branch, then the modification will not work.
|
# if you modify the var in just one branch, then the modification will not work.
|
||||||
fea = inputs[-1]
|
fea = inputs[-1]
|
||||||
@ -67,7 +67,7 @@ class TableAttentionHead(nn.Layer):
|
|||||||
|
|
||||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||||
output_hiddens = []
|
output_hiddens = []
|
||||||
if mode == 'Train' and targets is not None:
|
if self.training and targets is not None:
|
||||||
structure = targets[0]
|
structure = targets[0]
|
||||||
for i in range(self.max_elem_length+1):
|
for i in range(self.max_elem_length+1):
|
||||||
elem_onehots = self._char_to_onehot(
|
elem_onehots = self._char_to_onehot(
|
||||||
|
@ -81,7 +81,7 @@ def main(config, device, logger, vdl_writer):
|
|||||||
batch = transform(data, ops)
|
batch = transform(data, ops)
|
||||||
images = np.expand_dims(batch[0], axis=0)
|
images = np.expand_dims(batch[0], axis=0)
|
||||||
images = paddle.to_tensor(images)
|
images = paddle.to_tensor(images)
|
||||||
preds = model(images, data=None, mode='Test')
|
preds = model(images)
|
||||||
post_result = post_process_class(preds)
|
post_result = post_process_class(preds)
|
||||||
res_html_code = post_result['res_html_code']
|
res_html_code = post_result['res_html_code']
|
||||||
res_loc = post_result['res_loc']
|
res_loc = post_result['res_loc']
|
||||||
|
Loading…
x
Reference in New Issue
Block a user