Fix an issue with FastCollateMixup still using device

This commit is contained in:
Ross Wightman 2023-05-10 08:55:38 -07:00
parent 569d114ef7
commit fd592ec86c

View File

@@ -16,8 +16,7 @@ import torch
 def one_hot(x, num_classes, on_value=1., off_value=0.):
     x = x.long().view(-1, 1)
-    device = x.device
-    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+    return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)

 def mixup_target(target, num_classes, lam=1., smoothing=0.0):
@@ -311,7 +310,7 @@ class FastCollateMixup(Mixup):
         else:
             lam = self._mix_batch_collate(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         target = target[:batch_size]
         return output, target