通过变量类型判断是否是visual
parent
c25eec882a
commit
483e503826
|
@ -36,22 +36,26 @@ class RFLLoss(nn.Layer):
|
|||
|
||||
self.total_loss = {}
|
||||
total_loss = 0.0
|
||||
if isinstance(predicts, tuple) or isinstance(predicts, list):
|
||||
cnt_outputs, seq_outputs = predicts
|
||||
else:
|
||||
cnt_outputs, seq_outputs = predicts, None
|
||||
# batch [image, label, length, cnt_label]
|
||||
if predicts[0] is not None:
|
||||
cnt_loss = self.cnt_loss(predicts[0],
|
||||
if cnt_outputs is not None:
|
||||
cnt_loss = self.cnt_loss(cnt_outputs,
|
||||
paddle.cast(batch[3], paddle.float32))
|
||||
self.total_loss['cnt_loss'] = cnt_loss
|
||||
total_loss += cnt_loss
|
||||
|
||||
if predicts[1] is not None:
|
||||
if seq_outputs is not None:
|
||||
targets = batch[1].astype("int64")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
|
||||
1].shape[1], predicts[1].shape[2]
|
||||
assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
|
||||
batch_size, num_steps, num_classes = seq_outputs.shape[
|
||||
0], seq_outputs.shape[1], seq_outputs.shape[2]
|
||||
assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
|
||||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||
|
||||
inputs = predicts[1][:, :-1, :]
|
||||
inputs = seq_outputs[:, :-1, :]
|
||||
targets = targets[:, 1:]
|
||||
|
||||
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
|
||||
|
|
|
@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
|||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if len(preds) == 2:
|
||||
cnt_pred, preds = preds
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
# if seq_outputs is not None:
|
||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||
cnt_outputs, seq_outputs = preds
|
||||
if isinstance(seq_outputs, paddle.Tensor):
|
||||
seq_outputs = seq_outputs.numpy()
|
||||
preds_idx = seq_outputs.argmax(axis=2)
|
||||
preds_prob = seq_outputs.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
|
||||
if label is None:
|
||||
|
@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
|||
return text, label
|
||||
|
||||
else:
|
||||
cnt_pred = preds
|
||||
if isinstance(cnt_pred, paddle.Tensor):
|
||||
cnt_pred = cnt_pred.numpy()
|
||||
cnt_outputs = preds
|
||||
if isinstance(cnt_outputs, paddle.Tensor):
|
||||
cnt_outputs = cnt_outputs.numpy()
|
||||
cnt_length = []
|
||||
for lens in cnt_pred:
|
||||
for lens in cnt_outputs:
|
||||
length = round(np.sum(lens))
|
||||
cnt_length.append(length)
|
||||
if label is None:
|
||||
|
|
Loading…
Reference in New Issue