From 569d114ef7b6728c10d91c092da3807ea4ec2941 Mon Sep 17 00:00:00 2001 From: Junming Chen <72211694+Leoooo333@users.noreply.github.com> Date: Tue, 19 Apr 2022 11:53:18 +0800 Subject: [PATCH 1/2] Fix device problem Before, the one_hot could only run in device='cuda'. Now it will run on input device automatically. --- timm/data/mixup.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index c8789a0c..4da2bc66 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -14,16 +14,17 @@ import numpy as np import torch -def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): +def one_hot(x, num_classes, on_value=1., off_value=0.): x = x.long().view(-1, 1) + device = x.device return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) -def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): +def mixup_target(target, num_classes, lam=1., smoothing=0.0): off_value = smoothing / num_classes on_value = 1. - smoothing + off_value - y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) - y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) + y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value) + y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value) return y1 * lam + y2 * (1. - lam) @@ -214,7 +215,7 @@ class Mixup: lam = self._mix_pair(x) else: lam = self._mix_batch(x) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) return x, target From fd592ec86c1daf8f40ea504ed3ab50b5fcf79421 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 10 May 2023 08:55:38 -0700 Subject: [PATCH 2/2] Fix an issue with FastCollateMixup still using device --- timm/data/mixup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 4da2bc66..be0bae36 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -16,8 +16,7 @@ import torch def one_hot(x, num_classes, on_value=1., off_value=0.): x = x.long().view(-1, 1) - device = x.device - return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) + return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value) def mixup_target(target, num_classes, lam=1., smoothing=0.0): @@ -311,7 +310,7 @@ class FastCollateMixup(Mixup): else: lam = self._mix_batch_collate(output, batch) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) target = target[:batch_size] return output, target