通过变量类型判断是否是visual

pull/7741/head
zhiminzhang0830 2022-10-10 12:12:47 +08:00
parent c25eec882a
commit 483e503826
2 changed files with 22 additions and 17 deletions

View File

@ -36,22 +36,26 @@ class RFLLoss(nn.Layer):
self.total_loss = {} self.total_loss = {}
total_loss = 0.0 total_loss = 0.0
if isinstance(predicts, tuple) or isinstance(predicts, list):
cnt_outputs, seq_outputs = predicts
else:
cnt_outputs, seq_outputs = predicts, None
# batch [image, label, length, cnt_label] # batch [image, label, length, cnt_label]
if predicts[0] is not None: if cnt_outputs is not None:
cnt_loss = self.cnt_loss(predicts[0], cnt_loss = self.cnt_loss(cnt_outputs,
paddle.cast(batch[3], paddle.float32)) paddle.cast(batch[3], paddle.float32))
self.total_loss['cnt_loss'] = cnt_loss self.total_loss['cnt_loss'] = cnt_loss
total_loss += cnt_loss total_loss += cnt_loss
if predicts[1] is not None: if seq_outputs is not None:
targets = batch[1].astype("int64") targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64') label_lengths = batch[2].astype('int64')
batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[ batch_size, num_steps, num_classes = seq_outputs.shape[
1].shape[1], predicts[1].shape[2] 0], seq_outputs.shape[1], seq_outputs.shape[2]
assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \ assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]" "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = predicts[1][:, :-1, :] inputs = seq_outputs[:, :-1, :]
targets = targets[:, 1:] targets = targets[:, 1:]
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]]) inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])

View File

@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode):
return result_list return result_list
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if len(preds) == 2: # if seq_outputs is not None:
cnt_pred, preds = preds if isinstance(preds, tuple) or isinstance(preds, list):
if isinstance(preds, paddle.Tensor): cnt_outputs, seq_outputs = preds
preds = preds.numpy() if isinstance(seq_outputs, paddle.Tensor):
preds_idx = preds.argmax(axis=2) seq_outputs = seq_outputs.numpy()
preds_prob = preds.max(axis=2) preds_idx = seq_outputs.argmax(axis=2)
preds_prob = seq_outputs.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None: if label is None:
@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode):
return text, label return text, label
else: else:
cnt_pred = preds cnt_outputs = preds
if isinstance(cnt_pred, paddle.Tensor): if isinstance(cnt_outputs, paddle.Tensor):
cnt_pred = cnt_pred.numpy() cnt_outputs = cnt_outputs.numpy()
cnt_length = [] cnt_length = []
for lens in cnt_pred: for lens in cnt_outputs:
length = round(np.sum(lens)) length = round(np.sum(lens))
cnt_length.append(length) cnt_length.append(length)
if label is None: if label is None: