fix sar train on cpu
parent
529133fb3f
commit
92c9eaf314
|
@ -275,7 +275,6 @@ class ParallelSARDecoder(BaseDecoder):
|
||||||
if img_metas is not None and self.mask:
|
if img_metas is not None and self.mask:
|
||||||
valid_ratios = img_metas[-1]
|
valid_ratios = img_metas[-1]
|
||||||
|
|
||||||
label = label.cuda()
|
|
||||||
lab_embedding = self.embedding(label)
|
lab_embedding = self.embedding(label)
|
||||||
# bsz * seq_len * emb_dim
|
# bsz * seq_len * emb_dim
|
||||||
out_enc = out_enc.unsqueeze(1)
|
out_enc = out_enc.unsqueeze(1)
|
||||||
|
|
Loading…
Reference in New Issue