mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
fix device check
This commit is contained in:
parent 234f975787
commit 37c731ca37
@@ -113,8 +113,8 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
-        self.is_npu = torch.npu.is_available() and device.type == 'npu'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
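The fix relies on Python's left-to-right, short-circuiting `and`: with the device type compared first, `torch.npu.is_available()` is only reached when the device really is an NPU, so a stock CPU/CUDA build of PyTorch (where `torch` has no `npu` attribute) never hits an AttributeError. A minimal standalone sketch of the corrected pattern, assuming a plain PyTorch install without the torch_npu extension:

import torch

# Pick whatever device is locally available; an 'npu' device would require torch_npu.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Device type first, availability probe second: the right-hand operand is
# skipped for any device of a different type, so torch.npu is never touched
# on builds that do not define it.
is_cuda = device.type == 'cuda' and torch.cuda.is_available()
is_npu = device.type == 'npu' and torch.npu.is_available()

print(f'is_cuda={is_cuda}, is_npu={is_npu}')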
@@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            elif torch.npu.is_available() and "npu" in args.device:
+            elif "npu" in args.device and torch.npu.is_available():
                 torch.npu.empty_cache()
             results = validate(args)
             return results
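The same reordering is applied to the guard around the per-attempt cache clear: the device string is inspected before the backend availability call, so the probe short-circuits on machines without `torch.npu`. A small self-contained sketch of that guard (not the repo's validate.py, just the pattern in isolation):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

if 'cuda' in device and torch.cuda.is_available():
    # Release cached blocks so the next attempt starts from a clean pool.
    torch.cuda.empty_cache()
elif 'npu' in device and torch.npu.is_available():
    # Only reached when the device string names an NPU, so torch.npu is
    # never accessed on builds that lack it.
    torch.npu.empty_cache()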