mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix an issue with FastCollateMixup still using device
This commit is contained in:
parent
569d114ef7
commit
fd592ec86c
@ -16,8 +16,7 @@ import torch
|
|||||||
|
|
||||||
def one_hot(x, num_classes, on_value=1., off_value=0.):
|
def one_hot(x, num_classes, on_value=1., off_value=0.):
|
||||||
x = x.long().view(-1, 1)
|
x = x.long().view(-1, 1)
|
||||||
device = x.device
|
return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
|
||||||
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
|
||||||
|
|
||||||
|
|
||||||
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
|
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
|
||||||
@ -311,7 +310,7 @@ class FastCollateMixup(Mixup):
|
|||||||
else:
|
else:
|
||||||
lam = self._mix_batch_collate(output, batch)
|
lam = self._mix_batch_collate(output, batch)
|
||||||
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||||
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
|
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
|
||||||
target = target[:batch_size]
|
target = target[:batch_size]
|
||||||
return output, target
|
return output, target
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user