mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix device problem
Before, the one_hot could only run in device='cuda'. Now it will run on input device automatically.
This commit is contained in:
parent
01a0e25a67
commit
569d114ef7
@ -14,16 +14,17 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
|
def one_hot(x, num_classes, on_value=1., off_value=0.):
|
||||||
x = x.long().view(-1, 1)
|
x = x.long().view(-1, 1)
|
||||||
|
device = x.device
|
||||||
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
||||||
|
|
||||||
|
|
||||||
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
|
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
|
||||||
off_value = smoothing / num_classes
|
off_value = smoothing / num_classes
|
||||||
on_value = 1. - smoothing + off_value
|
on_value = 1. - smoothing + off_value
|
||||||
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
|
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
|
||||||
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
|
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
|
||||||
return y1 * lam + y2 * (1. - lam)
|
return y1 * lam + y2 * (1. - lam)
|
||||||
|
|
||||||
|
|
||||||
@ -214,7 +215,7 @@ class Mixup:
|
|||||||
lam = self._mix_pair(x)
|
lam = self._mix_pair(x)
|
||||||
else:
|
else:
|
||||||
lam = self._mix_batch(x)
|
lam = self._mix_batch(x)
|
||||||
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
|
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
|
||||||
return x, target
|
return x, target
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user