fix sar train on cpu

pull/4495/head
andyjpaddle 2021-10-29 08:28:59 +00:00
parent 529133fb3f
commit 92c9eaf314
1 changed files with 0 additions and 1 deletions

View File

@ -275,7 +275,6 @@ class ParallelSARDecoder(BaseDecoder):
if img_metas is not None and self.mask: if img_metas is not None and self.mask:
valid_ratios = img_metas[-1] valid_ratios = img_metas[-1]
label = label.cuda()
lab_embedding = self.embedding(label) lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim # bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1) out_enc = out_enc.unsqueeze(1)