通过变量类型判断是否是visual
parent
c25eec882a
commit
483e503826
|
@ -36,22 +36,26 @@ class RFLLoss(nn.Layer):
|
||||||
|
|
||||||
self.total_loss = {}
|
self.total_loss = {}
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
if isinstance(predicts, tuple) or isinstance(predicts, list):
|
||||||
|
cnt_outputs, seq_outputs = predicts
|
||||||
|
else:
|
||||||
|
cnt_outputs, seq_outputs = predicts, None
|
||||||
# batch [image, label, length, cnt_label]
|
# batch [image, label, length, cnt_label]
|
||||||
if predicts[0] is not None:
|
if cnt_outputs is not None:
|
||||||
cnt_loss = self.cnt_loss(predicts[0],
|
cnt_loss = self.cnt_loss(cnt_outputs,
|
||||||
paddle.cast(batch[3], paddle.float32))
|
paddle.cast(batch[3], paddle.float32))
|
||||||
self.total_loss['cnt_loss'] = cnt_loss
|
self.total_loss['cnt_loss'] = cnt_loss
|
||||||
total_loss += cnt_loss
|
total_loss += cnt_loss
|
||||||
|
|
||||||
if predicts[1] is not None:
|
if seq_outputs is not None:
|
||||||
targets = batch[1].astype("int64")
|
targets = batch[1].astype("int64")
|
||||||
label_lengths = batch[2].astype('int64')
|
label_lengths = batch[2].astype('int64')
|
||||||
batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
|
batch_size, num_steps, num_classes = seq_outputs.shape[
|
||||||
1].shape[1], predicts[1].shape[2]
|
0], seq_outputs.shape[1], seq_outputs.shape[2]
|
||||||
assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
|
assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
|
||||||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||||
|
|
||||||
inputs = predicts[1][:, :-1, :]
|
inputs = seq_outputs[:, :-1, :]
|
||||||
targets = targets[:, 1:]
|
targets = targets[:, 1:]
|
||||||
|
|
||||||
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
|
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
|
||||||
|
|
|
@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
||||||
return result_list
|
return result_list
|
||||||
|
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
if len(preds) == 2:
|
# if seq_outputs is not None:
|
||||||
cnt_pred, preds = preds
|
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||||
if isinstance(preds, paddle.Tensor):
|
cnt_outputs, seq_outputs = preds
|
||||||
preds = preds.numpy()
|
if isinstance(seq_outputs, paddle.Tensor):
|
||||||
preds_idx = preds.argmax(axis=2)
|
seq_outputs = seq_outputs.numpy()
|
||||||
preds_prob = preds.max(axis=2)
|
preds_idx = seq_outputs.argmax(axis=2)
|
||||||
|
preds_prob = seq_outputs.max(axis=2)
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
|
||||||
if label is None:
|
if label is None:
|
||||||
|
@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
||||||
return text, label
|
return text, label
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cnt_pred = preds
|
cnt_outputs = preds
|
||||||
if isinstance(cnt_pred, paddle.Tensor):
|
if isinstance(cnt_outputs, paddle.Tensor):
|
||||||
cnt_pred = cnt_pred.numpy()
|
cnt_outputs = cnt_outputs.numpy()
|
||||||
cnt_length = []
|
cnt_length = []
|
||||||
for lens in cnt_pred:
|
for lens in cnt_outputs:
|
||||||
length = round(np.sum(lens))
|
length = round(np.sum(lens))
|
||||||
cnt_length.append(length)
|
cnt_length.append(length)
|
||||||
if label is None:
|
if label is None:
|
||||||
|
|
Loading…
Reference in New Issue