From 37c731ca370e26e1d888f308559ee60a68951779 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Thu, 17 Oct 2024 12:38:02 +0000
Subject: [PATCH] fix device check

---
 timm/data/loader.py | 4 ++--
 validate.py         | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index d3300ea8..3b4a6d0e 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -113,8 +113,8 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
-        self.is_npu = torch.npu.is_available() and device.type == 'npu'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()

     def __iter__(self):
         first = True
diff --git a/validate.py b/validate.py
index 6623453b..ce0e4b25 100755
--- a/validate.py
+++ b/validate.py
@@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            elif torch.npu.is_available() and "npu" in args.device:
+            elif "npu" in args.device and torch.npu.is_available():
                 torch.npu.empty_cache()
             results = validate(args)
             return results
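
Not part of the patch above: a minimal, standalone sketch of why the operand order matters. Python's `and` short-circuits, so with the new order the `torch.npu` attribute is only touched when the device type is actually 'npu'. On a stock PyTorch build without the Ascend torch_npu plugin, `torch.npu` does not exist, so probing availability first raises AttributeError even for CUDA/CPU users. The `device` value below is a hypothetical example, not taken from the patch.

import torch

device = torch.device('cpu')  # hypothetical example device; no NPU or CUDA required

# Old order: the availability probe runs first and fails on builds without torch_npu.
try:
    is_npu = torch.npu.is_available() and device.type == 'npu'
except AttributeError as err:
    print(f'old order fails: {err}')

# New order: the cheap device-type check is evaluated first and is False here,
# so torch.npu is never accessed and no AttributeError can be raised.
is_npu = device.type == 'npu' and torch.npu.is_available()
print(f'new order: is_npu={is_npu}')

The same reasoning applies to the validate.py hunk: checking `"npu" in args.device` first keeps `torch.npu.is_available()` from ever being evaluated on CUDA- or CPU-only setups.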